In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import numpy as np
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import torchvision.utils as vutils

# --- Dataset Class ---
class PolymerDataset(Dataset):
    def __init__(self, input_dir, target_dir, size=(512, 512)):
        self.input_dir = input_dir
        self.target_dir = target_dir
        self.size = size
        self.filenames = sorted(os.listdir(input_dir))
        
        self.transform = transforms.Compose([
            transforms.Resize(size),
            transforms.ToTensor(),
        ])
    
    def __len__(self):
        return len(self.filenames)
    
    def __getitem__(self, idx):
        input_path = os.path.join(self.input_dir, self.filenames[idx])
        target_path = os.path.join(self.target_dir, self.filenames[idx])
        
        input_img = Image.open(input_path).convert('L')
        target_img = Image.open(target_path).convert('L')
        
        input_tensor = self.transform(input_img)
        target_tensor = self.transform(target_img)
        
        input_tensor = (input_tensor > 0.5).float() * 2 - 1
        target_tensor = (target_tensor > 0.5).float() * 2 - 1
        
        return input_tensor, target_tensor

In [3]:
# --- UNet Generator ---
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        self.conv_in = nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1, bias=False)
        
        # Downsampling
        self.down1 = nn.Sequential(
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128)
        )
        self.down2 = self._make_down_block(128, 256)
        self.down3 = self._make_down_block(256, 512)
        self.down4 = self._make_down_block(512, 512)
        self.down5 = self._make_down_block(512, 512)
        self.down6 = self._make_down_block(512, 512)

        # Middle
        self.middle = nn.Sequential(
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(512, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512)
        )

        # Upsampling
        self.up1 = self._make_up_block(1024, 512, dropout=True)
        self.up2 = self._make_up_block(1024, 512, dropout=True)
        self.up3 = self._make_up_block(1024, 512, dropout=True)
        self.up4 = self._make_up_block(1024, 256)
        self.up5 = self._make_up_block(512, 128)
        self.up6 = self._make_up_block(256, 64)

        self.outermost = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def _make_down_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_channels)
        )

    def _make_up_block(self, in_channels, out_channels, dropout=False):
        layers = [
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_channels)
        ]
        if dropout:
            layers.append(nn.Dropout(0.5))
        return nn.Sequential(*layers)

    def forward(self, x):
        # Encoder
        x0 = self.conv_in(x)
        x1 = self.down1(x0)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        x4 = self.down4(x3)
        x5 = self.down5(x4)
        x6 = self.down6(x5)

        # Middle
        x = self.middle(x6)

        # Decoder with skip connections
        x = self.up1(torch.cat([x, x6], dim=1))
        x = self.up2(torch.cat([x, x5], dim=1))
        x = self.up3(torch.cat([x, x4], dim=1))
        x = self.up4(torch.cat([x, x3], dim=1))
        x = self.up5(torch.cat([x, x2], dim=1))
        x = self.up6(torch.cat([x, x1], dim=1))
        
        return self.outermost(torch.cat([x, x0], dim=1))

In [4]:
# --- PatchGAN Discriminator ---
class PatchGAN(nn.Module):
    def __init__(self):
        super(PatchGAN, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(2, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1)
        )

    def forward(self, x):
        return self.net(x)

# --- Learning Rate Scheduler ---
class LambdaLR:
    def __init__(self, n_epochs, offset, decay_start_epoch):
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch) / (self.n_epochs - self.decay_start_epoch)

# --- Visualization Function ---
def visualize_results(inputs, targets, outputs, epoch, save_dir="training_results"):
    os.makedirs(save_dir, exist_ok=True)
    
    # Denormalize images
    inputs = (inputs.cpu().numpy() + 1) / 2
    targets = (targets.cpu().numpy() + 1) / 2
    outputs = (outputs.cpu().numpy() + 1) / 2
    
    plt.figure(figsize=(15, 5))
    for i in range(min(3, inputs.shape[0])):  # Show max 3 examples
        plt.subplot(3, 3, i*3+1)
        plt.imshow(inputs[i][0], cmap='gray')
        plt.title("Input")
        plt.axis('off')
        
        plt.subplot(3, 3, i*3+2)
        plt.imshow(targets[i][0], cmap='gray')
        plt.title("Target")
        plt.axis('off')
        
        plt.subplot(3, 3, i*3+3)
        plt.imshow(outputs[i][0], cmap='gray')
        plt.title("Output")
        plt.axis('off')
    
    plt.suptitle(f"Epoch {epoch+1}")
    plt.savefig(f"{save_dir}/epoch_{epoch+1}.png")
    plt.close()

In [5]:
# Load dataset
input_dir = '/kaggle/input/dataset1/sliced images/input'
target_dir = '/kaggle/input/dataset1/sliced images/target'
    

# Split dataset
dataset = PolymerDataset(input_dir, target_dir)

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
gen = torch.Generator()
gen.manual_seed(30)
train_dataset, test_dataset = random_split(dataset, [train_size, test_size], generator=gen)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)

In [6]:
len(test_loader)

66

In [7]:
len(train_loader)

264

In [8]:
# --- Training Function with Visualization ---
def train():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Initialize models
    netG = UNet().to(device)
    netD = PatchGAN().to(device)
    
    # Loss functions
    criterion_GAN = nn.BCEWithLogitsLoss()
    criterion_L1 = nn.L1Loss()
    lambda_L1 = 20
    
    # Optimizers
    optimizer_G = torch.optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(netD.parameters(), lr=0.0001, betas=(0.5, 0.999))
    
    # Learning rate schedulers
    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G, lr_lambda=LambdaLR(100, 0, 50).step)
    lr_scheduler_D = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D, lr_lambda=LambdaLR(100, 0, 50).step)

    
    # Create fixed batch for visualization
    fixed_inputs, fixed_targets = next(iter(test_loader))
    fixed_inputs = fixed_inputs.to(device)
    
    # History tracking
    history = {
        'epoch': [],
        'G_loss': [],
        'D_loss': [],
        'D_real': [],
        'D_fake': [],
        'D_x': [],       # D(x) - average output on real images
        'D_G_z': [],     # D(G(z)) - average output on fake images
        'val_loss': []
    }
    
    # Training loop
    for epoch in range(100):
        netG.train()
        netD.train()
        
        # Initialize accumulators
        running_loss_G = 0.0
        running_loss_D = 0.0
        running_loss_D_real = 0.0
        running_loss_D_fake = 0.0
        running_D_x = 0.0
        running_D_G_z = 0.0
        d_steps = 0
        
        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/100')
        for i, (input_imgs, target_imgs) in enumerate(progress_bar):
            input_imgs, target_imgs = input_imgs.to(device), target_imgs.to(device)
            
            # Train Generator
            optimizer_G.zero_grad()
            fake_B = netG(input_imgs)
            fake_AB = torch.cat((input_imgs, fake_B), 1)
            
            pred_fake = netD(fake_AB)
            loss_G_GAN = criterion_GAN(pred_fake, torch.ones_like(pred_fake))
            loss_G_L1 = criterion_L1(fake_B, target_imgs) * lambda_L1
            loss_G = loss_G_GAN + loss_G_L1
            loss_G.backward()
            optimizer_G.step()
            
            running_loss_G += loss_G.item()
            running_D_G_z += torch.sigmoid(pred_fake).mean().item()
            
            # Train Discriminator (every 2nd step)
            if i % 2 == 0:
                optimizer_D.zero_grad()
                
                # Real images
                real_AB = torch.cat((input_imgs, target_imgs), 1)
                pred_real = netD(real_AB)
                real_labels = torch.rand_like(pred_real)*0.1 + 0.9
                loss_D_real = criterion_GAN(pred_real, real_labels)
                
                # Fake images
                fake_AB = torch.cat((input_imgs, fake_B.detach()), 1)
                pred_fake = netD(fake_AB)
                fake_labels = torch.rand_like(pred_fake)*0.1
                loss_D_fake = criterion_GAN(pred_fake, fake_labels)
                
                loss_D = (loss_D_real + loss_D_fake) * 0.5
                loss_D.backward()
                optimizer_D.step()
                
                running_loss_D += loss_D.item()
                running_loss_D_real += loss_D_real.item()
                running_loss_D_fake += loss_D_fake.item()
                running_D_x += torch.sigmoid(pred_real).mean().item()
                d_steps += 1
            
            # Update progress bar
            progress_bar.set_postfix({
                'G': f'{running_loss_G/(i+1):.4f}',
                'D': f'{running_loss_D/d_steps:.4f}' if d_steps > 0 else 'skipped',
                'D(x)': f'{running_D_x/d_steps:.3f}' if d_steps > 0 else '-',
                'D(G(z))': f'{running_D_G_z/(i+1):.3f}'
            })
            
        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D.step()
        
        # Validation
        netG.eval()
        val_loss = 0
        with torch.no_grad():
            for input_imgs, target_imgs in test_loader:
                input_imgs, target_imgs = input_imgs.to(device), target_imgs.to(device)
                fake_B = netG(input_imgs)
                val_loss += criterion_L1(fake_B, target_imgs).item()
            
            # Generate visualization
            fixed_outputs = netG(fixed_inputs)
            visualize_results(fixed_inputs[:3], fixed_targets[:3].to(device), fixed_outputs[:3], epoch)
        
        # Calculate averages
        avg_loss_G = running_loss_G / len(train_loader)
        avg_loss_D = running_loss_D / d_steps if d_steps > 0 else 0
        avg_loss_D_real = running_loss_D_real / d_steps if d_steps > 0 else 0
        avg_loss_D_fake = running_loss_D_fake / d_steps if d_steps > 0 else 0
        avg_D_x = running_D_x / d_steps if d_steps > 0 else 0
        avg_D_G_z = running_D_G_z / len(train_loader)
        avg_val_loss = val_loss / len(test_loader)
        
        # Store history
        history['epoch'].append(epoch+1)
        history['G_loss'].append(avg_loss_G)
        history['D_loss'].append(avg_loss_D)
        history['D_real'].append(avg_loss_D_real)
        history['D_fake'].append(avg_loss_D_fake)
        history['D_x'].append(avg_D_x)
        history['D_G_z'].append(avg_D_G_z)
        history['val_loss'].append(avg_val_loss)
        
        # Print epoch summary
        print(f'\nEpoch {epoch+1} Summary:')
        print(f'G_loss: {avg_loss_G:.4f} | D_loss: {avg_loss_D:.4f}')
        print(f'D_real: {avg_loss_D_real:.4f} | D_fake: {avg_loss_D_fake:.4f}')
        print(f'D(x): {avg_D_x:.3f} | D(G(z)): {avg_D_G_z:.3f}')
        print(f'Val Loss: {avg_val_loss:.4f}')
        
        # Save models
        if (epoch + 1) % 10 == 0:
            torch.save(netG.state_dict(), f'generator_epoch_{epoch+1}.pth')
            torch.save(netD.state_dict(), f'discriminator_epoch_{epoch+1}.pth')
    
    # Save final models
    torch.save(netG.state_dict(), 'generator_final.pth')
    torch.save(netD.state_dict(), 'discriminator_final.pth')
    
    # Plot training history
    plot_training_history(history)

def plot_training_history(history):
    plt.figure(figsize=(15, 10))
    
    # Loss plot
    plt.subplot(2, 2, 1)
    plt.plot(history['epoch'], history['G_loss'], label='Generator Loss')
    plt.plot(history['epoch'], history['D_loss'], label='Discriminator Loss')
    plt.plot(history['epoch'], history['val_loss'], label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Losses')
    plt.legend()
    
    # D(x) and D(G(z)) plot
    plt.subplot(2, 2, 2)
    plt.plot(history['epoch'], history['D_x'], label='D(x) - Real')
    plt.plot(history['epoch'], history['D_G_z'], label='D(G(z)) - Fake')
    plt.xlabel('Epoch')
    plt.ylabel('Score')
    plt.title('Discriminator Output Scores')
    plt.legend()
    
    # D components plot
    plt.subplot(2, 2, 3)
    plt.plot(history['epoch'], history['D_real'], label='D Real Loss')
    plt.plot(history['epoch'], history['D_fake'], label='D Fake Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Discriminator Component Losses')
    plt.legend()
    
    # Save plots
    plt.tight_layout()
    plt.savefig('training_history.png')
    plt.close()

if __name__ == "__main__":
    train()

Epoch 1/100: 100%|██████████| 264/264 [01:33<00:00,  2.83it/s, G=7.3922, D=0.6139, D(x)=0.563, D(G(z))=0.432]



Epoch 1 Summary:
G_loss: 7.3922 | D_loss: 0.6139
D_real: 0.6194 | D_fake: 0.6084
D(x): 0.563 | D(G(z)): 0.432
Val Loss: 0.3928


Epoch 2/100: 100%|██████████| 264/264 [01:19<00:00,  3.34it/s, G=8.0334, D=0.4943, D(x)=0.666, D(G(z))=0.336]



Epoch 2 Summary:
G_loss: 8.0334 | D_loss: 0.4943
D_real: 0.5096 | D_fake: 0.4791
D(x): 0.666 | D(G(z)): 0.336
Val Loss: 0.3720


Epoch 3/100: 100%|██████████| 264/264 [01:19<00:00,  3.33it/s, G=8.1279, D=0.5385, D(x)=0.645, D(G(z))=0.353]



Epoch 3 Summary:
G_loss: 8.1279 | D_loss: 0.5385
D_real: 0.5733 | D_fake: 0.5037
D(x): 0.645 | D(G(z)): 0.353
Val Loss: 0.3970


Epoch 4/100: 100%|██████████| 264/264 [01:19<00:00,  3.34it/s, G=7.9506, D=0.5647, D(x)=0.621, D(G(z))=0.374]



Epoch 4 Summary:
G_loss: 7.9506 | D_loss: 0.5647
D_real: 0.6009 | D_fake: 0.5284
D(x): 0.621 | D(G(z)): 0.374
Val Loss: 0.3347


Epoch 5/100: 100%|██████████| 264/264 [01:19<00:00,  3.34it/s, G=7.8266, D=0.5753, D(x)=0.609, D(G(z))=0.389]



Epoch 5 Summary:
G_loss: 7.8266 | D_loss: 0.5753
D_real: 0.6029 | D_fake: 0.5478
D(x): 0.609 | D(G(z)): 0.389
Val Loss: 0.3332


Epoch 6/100: 100%|██████████| 264/264 [01:19<00:00,  3.34it/s, G=7.7062, D=0.5902, D(x)=0.596, D(G(z))=0.399]



Epoch 6 Summary:
G_loss: 7.7062 | D_loss: 0.5902
D_real: 0.6072 | D_fake: 0.5732
D(x): 0.596 | D(G(z)): 0.399
Val Loss: 0.3183


Epoch 7/100: 100%|██████████| 264/264 [01:19<00:00,  3.34it/s, G=7.6210, D=0.5883, D(x)=0.595, D(G(z))=0.401]



Epoch 7 Summary:
G_loss: 7.6210 | D_loss: 0.5883
D_real: 0.6116 | D_fake: 0.5651
D(x): 0.595 | D(G(z)): 0.401
Val Loss: 0.3053


Epoch 8/100: 100%|██████████| 264/264 [01:19<00:00,  3.33it/s, G=7.5021, D=0.5953, D(x)=0.591, D(G(z))=0.416]



Epoch 8 Summary:
G_loss: 7.5021 | D_loss: 0.5953
D_real: 0.6052 | D_fake: 0.5853
D(x): 0.591 | D(G(z)): 0.416
Val Loss: 0.3551


Epoch 9/100: 100%|██████████| 264/264 [01:18<00:00,  3.34it/s, G=7.4981, D=0.5936, D(x)=0.591, D(G(z))=0.408]



Epoch 9 Summary:
G_loss: 7.4981 | D_loss: 0.5936
D_real: 0.6121 | D_fake: 0.5750
D(x): 0.591 | D(G(z)): 0.408
Val Loss: 0.3219


Epoch 10/100: 100%|██████████| 264/264 [01:18<00:00,  3.35it/s, G=7.3453, D=0.5967, D(x)=0.587, D(G(z))=0.418]



Epoch 10 Summary:
G_loss: 7.3453 | D_loss: 0.5967
D_real: 0.6032 | D_fake: 0.5901
D(x): 0.587 | D(G(z)): 0.418
Val Loss: 0.3438


Epoch 11/100: 100%|██████████| 264/264 [01:18<00:00,  3.34it/s, G=7.4410, D=0.6017, D(x)=0.588, D(G(z))=0.400]



Epoch 11 Summary:
G_loss: 7.4410 | D_loss: 0.6017
D_real: 0.6200 | D_fake: 0.5834
D(x): 0.588 | D(G(z)): 0.400
Val Loss: 0.3087


Epoch 12/100: 100%|██████████| 264/264 [01:19<00:00,  3.33it/s, G=7.2515, D=0.6000, D(x)=0.583, D(G(z))=0.422]



Epoch 12 Summary:
G_loss: 7.2515 | D_loss: 0.6000
D_real: 0.6059 | D_fake: 0.5940
D(x): 0.583 | D(G(z)): 0.422
Val Loss: 0.3079


Epoch 13/100: 100%|██████████| 264/264 [01:18<00:00,  3.35it/s, G=7.2081, D=0.6064, D(x)=0.578, D(G(z))=0.420]



Epoch 13 Summary:
G_loss: 7.2081 | D_loss: 0.6064
D_real: 0.6154 | D_fake: 0.5974
D(x): 0.578 | D(G(z)): 0.420
Val Loss: 0.3370


Epoch 14/100: 100%|██████████| 264/264 [01:18<00:00,  3.34it/s, G=7.1540, D=0.6044, D(x)=0.576, D(G(z))=0.422]



Epoch 14 Summary:
G_loss: 7.1540 | D_loss: 0.6044
D_real: 0.6109 | D_fake: 0.5979
D(x): 0.576 | D(G(z)): 0.422
Val Loss: 0.3332


Epoch 15/100: 100%|██████████| 264/264 [01:18<00:00,  3.35it/s, G=7.1704, D=0.6042, D(x)=0.577, D(G(z))=0.421]



Epoch 15 Summary:
G_loss: 7.1704 | D_loss: 0.6042
D_real: 0.6140 | D_fake: 0.5945
D(x): 0.577 | D(G(z)): 0.421
Val Loss: 0.3138


Epoch 16/100: 100%|██████████| 264/264 [01:18<00:00,  3.34it/s, G=7.0822, D=0.6087, D(x)=0.578, D(G(z))=0.418]



Epoch 16 Summary:
G_loss: 7.0822 | D_loss: 0.6087
D_real: 0.6151 | D_fake: 0.6023
D(x): 0.578 | D(G(z)): 0.418
Val Loss: 0.3134


Epoch 17/100: 100%|██████████| 264/264 [01:18<00:00,  3.34it/s, G=6.9573, D=0.6006, D(x)=0.577, D(G(z))=0.424]



Epoch 17 Summary:
G_loss: 6.9573 | D_loss: 0.6006
D_real: 0.6074 | D_fake: 0.5938
D(x): 0.577 | D(G(z)): 0.424
Val Loss: 0.3260


Epoch 18/100: 100%|██████████| 264/264 [01:19<00:00,  3.34it/s, G=6.9196, D=0.6057, D(x)=0.576, D(G(z))=0.436]



Epoch 18 Summary:
G_loss: 6.9196 | D_loss: 0.6057
D_real: 0.6062 | D_fake: 0.6052
D(x): 0.576 | D(G(z)): 0.436
Val Loss: 0.3207


Epoch 19/100: 100%|██████████| 264/264 [01:19<00:00,  3.34it/s, G=6.8960, D=0.6173, D(x)=0.575, D(G(z))=0.430]



Epoch 19 Summary:
G_loss: 6.8960 | D_loss: 0.6173
D_real: 0.6229 | D_fake: 0.6118
D(x): 0.575 | D(G(z)): 0.430
Val Loss: 0.3378


Epoch 20/100: 100%|██████████| 264/264 [01:19<00:00,  3.34it/s, G=6.7901, D=0.6049, D(x)=0.576, D(G(z))=0.423]



Epoch 20 Summary:
G_loss: 6.7901 | D_loss: 0.6049
D_real: 0.6085 | D_fake: 0.6013
D(x): 0.576 | D(G(z)): 0.423
Val Loss: 0.3295


Epoch 21/100: 100%|██████████| 264/264 [01:18<00:00,  3.35it/s, G=6.7477, D=0.6049, D(x)=0.577, D(G(z))=0.422]



Epoch 21 Summary:
G_loss: 6.7477 | D_loss: 0.6049
D_real: 0.6123 | D_fake: 0.5975
D(x): 0.577 | D(G(z)): 0.422
Val Loss: 0.3349


Epoch 22/100: 100%|██████████| 264/264 [01:18<00:00,  3.36it/s, G=6.6178, D=0.6062, D(x)=0.576, D(G(z))=0.421]



Epoch 22 Summary:
G_loss: 6.6178 | D_loss: 0.6062
D_real: 0.6087 | D_fake: 0.6036
D(x): 0.576 | D(G(z)): 0.421
Val Loss: 0.3221


Epoch 23/100: 100%|██████████| 264/264 [01:18<00:00,  3.35it/s, G=6.5763, D=0.6135, D(x)=0.576, D(G(z))=0.427]



Epoch 23 Summary:
G_loss: 6.5763 | D_loss: 0.6135
D_real: 0.6123 | D_fake: 0.6148
D(x): 0.576 | D(G(z)): 0.427
Val Loss: 0.3264


Epoch 24/100: 100%|██████████| 264/264 [01:18<00:00,  3.35it/s, G=6.4958, D=0.6025, D(x)=0.578, D(G(z))=0.419]



Epoch 24 Summary:
G_loss: 6.4958 | D_loss: 0.6025
D_real: 0.6090 | D_fake: 0.5959
D(x): 0.578 | D(G(z)): 0.419
Val Loss: 0.3168


Epoch 25/100: 100%|██████████| 264/264 [01:18<00:00,  3.35it/s, G=6.6550, D=0.6058, D(x)=0.579, D(G(z))=0.418]



Epoch 25 Summary:
G_loss: 6.6550 | D_loss: 0.6058
D_real: 0.6155 | D_fake: 0.5961
D(x): 0.579 | D(G(z)): 0.418
Val Loss: 0.3208


Epoch 26/100: 100%|██████████| 264/264 [01:18<00:00,  3.35it/s, G=6.4146, D=0.6072, D(x)=0.580, D(G(z))=0.422]



Epoch 26 Summary:
G_loss: 6.4146 | D_loss: 0.6072
D_real: 0.6083 | D_fake: 0.6060
D(x): 0.580 | D(G(z)): 0.422
Val Loss: 0.3235


Epoch 27/100: 100%|██████████| 264/264 [01:18<00:00,  3.35it/s, G=6.3497, D=0.6021, D(x)=0.573, D(G(z))=0.423]



Epoch 27 Summary:
G_loss: 6.3497 | D_loss: 0.6021
D_real: 0.6069 | D_fake: 0.5973
D(x): 0.573 | D(G(z)): 0.423
Val Loss: 0.3417


Epoch 28/100: 100%|██████████| 264/264 [01:18<00:00,  3.35it/s, G=6.3763, D=0.5970, D(x)=0.581, D(G(z))=0.419]



Epoch 28 Summary:
G_loss: 6.3763 | D_loss: 0.5970
D_real: 0.6025 | D_fake: 0.5916
D(x): 0.581 | D(G(z)): 0.419
Val Loss: 0.3271


Epoch 29/100: 100%|██████████| 264/264 [01:18<00:00,  3.35it/s, G=6.3688, D=0.6007, D(x)=0.576, D(G(z))=0.417]



Epoch 29 Summary:
G_loss: 6.3688 | D_loss: 0.6007
D_real: 0.6040 | D_fake: 0.5975
D(x): 0.576 | D(G(z)): 0.417
Val Loss: 0.3317


Epoch 30/100: 100%|██████████| 264/264 [01:19<00:00,  3.34it/s, G=6.3337, D=0.5932, D(x)=0.585, D(G(z))=0.418]



Epoch 30 Summary:
G_loss: 6.3337 | D_loss: 0.5932
D_real: 0.5952 | D_fake: 0.5912
D(x): 0.585 | D(G(z)): 0.418
Val Loss: 0.3330


Epoch 31/100: 100%|██████████| 264/264 [01:18<00:00,  3.36it/s, G=6.3220, D=0.6000, D(x)=0.580, D(G(z))=0.416]



Epoch 31 Summary:
G_loss: 6.3220 | D_loss: 0.6000
D_real: 0.6056 | D_fake: 0.5945
D(x): 0.580 | D(G(z)): 0.416
Val Loss: 0.3358


Epoch 32/100: 100%|██████████| 264/264 [01:18<00:00,  3.35it/s, G=6.3405, D=0.5971, D(x)=0.582, D(G(z))=0.413]



Epoch 32 Summary:
G_loss: 6.3405 | D_loss: 0.5971
D_real: 0.6009 | D_fake: 0.5933
D(x): 0.582 | D(G(z)): 0.413
Val Loss: 0.3381


Epoch 33/100: 100%|██████████| 264/264 [01:18<00:00,  3.35it/s, G=6.3113, D=0.6021, D(x)=0.580, D(G(z))=0.416]



Epoch 33 Summary:
G_loss: 6.3113 | D_loss: 0.6021
D_real: 0.6088 | D_fake: 0.5953
D(x): 0.580 | D(G(z)): 0.416
Val Loss: 0.3189


Epoch 34/100: 100%|██████████| 264/264 [01:18<00:00,  3.35it/s, G=6.3117, D=0.5965, D(x)=0.584, D(G(z))=0.409]



Epoch 34 Summary:
G_loss: 6.3117 | D_loss: 0.5965
D_real: 0.5965 | D_fake: 0.5965
D(x): 0.584 | D(G(z)): 0.409
Val Loss: 0.3326


Epoch 35/100: 100%|██████████| 264/264 [01:18<00:00,  3.35it/s, G=6.2873, D=0.5926, D(x)=0.585, D(G(z))=0.419]



Epoch 35 Summary:
G_loss: 6.2873 | D_loss: 0.5926
D_real: 0.5944 | D_fake: 0.5908
D(x): 0.585 | D(G(z)): 0.419
Val Loss: 0.3214


Epoch 36/100: 100%|██████████| 264/264 [01:19<00:00,  3.34it/s, G=6.3069, D=0.5963, D(x)=0.583, D(G(z))=0.410]



Epoch 36 Summary:
G_loss: 6.3069 | D_loss: 0.5963
D_real: 0.6009 | D_fake: 0.5916
D(x): 0.583 | D(G(z)): 0.410
Val Loss: 0.3271


Epoch 37/100: 100%|██████████| 264/264 [01:18<00:00,  3.34it/s, G=6.2145, D=0.5942, D(x)=0.586, D(G(z))=0.412]



Epoch 37 Summary:
G_loss: 6.2145 | D_loss: 0.5942
D_real: 0.5995 | D_fake: 0.5889
D(x): 0.586 | D(G(z)): 0.412
Val Loss: 0.3267


Epoch 38/100: 100%|██████████| 264/264 [01:19<00:00,  3.34it/s, G=6.2565, D=0.5940, D(x)=0.583, D(G(z))=0.417]



Epoch 38 Summary:
G_loss: 6.2565 | D_loss: 0.5940
D_real: 0.5927 | D_fake: 0.5954
D(x): 0.583 | D(G(z)): 0.417
Val Loss: 0.3422


Epoch 39/100: 100%|██████████| 264/264 [01:18<00:00,  3.36it/s, G=6.2398, D=0.5967, D(x)=0.583, D(G(z))=0.409]



Epoch 39 Summary:
G_loss: 6.2398 | D_loss: 0.5967
D_real: 0.6011 | D_fake: 0.5922
D(x): 0.583 | D(G(z)): 0.409
Val Loss: 0.3380


Epoch 40/100: 100%|██████████| 264/264 [01:18<00:00,  3.36it/s, G=6.2315, D=0.5947, D(x)=0.584, D(G(z))=0.413]



Epoch 40 Summary:
G_loss: 6.2315 | D_loss: 0.5947
D_real: 0.5960 | D_fake: 0.5935
D(x): 0.584 | D(G(z)): 0.413
Val Loss: 0.3350


Epoch 41/100: 100%|██████████| 264/264 [01:18<00:00,  3.35it/s, G=6.2536, D=0.5952, D(x)=0.584, D(G(z))=0.408]



Epoch 41 Summary:
G_loss: 6.2536 | D_loss: 0.5952
D_real: 0.5952 | D_fake: 0.5951
D(x): 0.584 | D(G(z)): 0.408
Val Loss: 0.3548


Epoch 42/100: 100%|██████████| 264/264 [01:18<00:00,  3.36it/s, G=6.2032, D=0.5937, D(x)=0.584, D(G(z))=0.412]



Epoch 42 Summary:
G_loss: 6.2032 | D_loss: 0.5937
D_real: 0.5969 | D_fake: 0.5904
D(x): 0.584 | D(G(z)): 0.412
Val Loss: 0.3374


Epoch 43/100: 100%|██████████| 264/264 [01:18<00:00,  3.34it/s, G=6.2305, D=0.6000, D(x)=0.585, D(G(z))=0.410]



Epoch 43 Summary:
G_loss: 6.2305 | D_loss: 0.6000
D_real: 0.6031 | D_fake: 0.5968
D(x): 0.585 | D(G(z)): 0.410
Val Loss: 0.3251


Epoch 44/100: 100%|██████████| 264/264 [01:18<00:00,  3.35it/s, G=6.2941, D=0.5933, D(x)=0.587, D(G(z))=0.413]



Epoch 44 Summary:
G_loss: 6.2941 | D_loss: 0.5933
D_real: 0.5980 | D_fake: 0.5887
D(x): 0.587 | D(G(z)): 0.413
Val Loss: 0.3409


Epoch 45/100: 100%|██████████| 264/264 [01:18<00:00,  3.35it/s, G=6.2357, D=0.5882, D(x)=0.589, D(G(z))=0.410]



Epoch 45 Summary:
G_loss: 6.2357 | D_loss: 0.5882
D_real: 0.5882 | D_fake: 0.5882
D(x): 0.589 | D(G(z)): 0.410
Val Loss: 0.3211


Epoch 46/100: 100%|██████████| 264/264 [01:18<00:00,  3.35it/s, G=6.2367, D=0.5778, D(x)=0.591, D(G(z))=0.407]



Epoch 46 Summary:
G_loss: 6.2367 | D_loss: 0.5778
D_real: 0.5808 | D_fake: 0.5749
D(x): 0.591 | D(G(z)): 0.407
Val Loss: 0.3287


Epoch 47/100: 100%|██████████| 264/264 [01:18<00:00,  3.36it/s, G=6.2617, D=0.5981, D(x)=0.589, D(G(z))=0.409]



Epoch 47 Summary:
G_loss: 6.2617 | D_loss: 0.5981
D_real: 0.6045 | D_fake: 0.5918
D(x): 0.589 | D(G(z)): 0.409
Val Loss: 0.3320


Epoch 48/100: 100%|██████████| 264/264 [01:18<00:00,  3.34it/s, G=6.2181, D=0.5827, D(x)=0.592, D(G(z))=0.410]



Epoch 48 Summary:
G_loss: 6.2181 | D_loss: 0.5827
D_real: 0.5861 | D_fake: 0.5793
D(x): 0.592 | D(G(z)): 0.410
Val Loss: 0.3291


Epoch 49/100: 100%|██████████| 264/264 [01:18<00:00,  3.35it/s, G=6.2538, D=0.5832, D(x)=0.590, D(G(z))=0.401]



Epoch 49 Summary:
G_loss: 6.2538 | D_loss: 0.5832
D_real: 0.5891 | D_fake: 0.5772
D(x): 0.590 | D(G(z)): 0.401
Val Loss: 0.3562


Epoch 50/100: 100%|██████████| 264/264 [01:18<00:00,  3.36it/s, G=6.2467, D=0.5863, D(x)=0.592, D(G(z))=0.410]



Epoch 50 Summary:
G_loss: 6.2467 | D_loss: 0.5863
D_real: 0.5858 | D_fake: 0.5869
D(x): 0.592 | D(G(z)): 0.410
Val Loss: 0.3260


Epoch 51/100: 100%|██████████| 264/264 [01:18<00:00,  3.35it/s, G=6.3212, D=0.6145, D(x)=0.587, D(G(z))=0.407]



Epoch 51 Summary:
G_loss: 6.3212 | D_loss: 0.6145
D_real: 0.6177 | D_fake: 0.6113
D(x): 0.587 | D(G(z)): 0.407
Val Loss: 0.3196


Epoch 52/100: 100%|██████████| 264/264 [01:18<00:00,  3.35it/s, G=6.1859, D=0.5764, D(x)=0.590, D(G(z))=0.409]



Epoch 52 Summary:
G_loss: 6.1859 | D_loss: 0.5764
D_real: 0.5771 | D_fake: 0.5756
D(x): 0.590 | D(G(z)): 0.409
Val Loss: 0.3270


Epoch 53/100: 100%|██████████| 264/264 [01:18<00:00,  3.36it/s, G=6.2601, D=0.5791, D(x)=0.592, D(G(z))=0.402]



Epoch 53 Summary:
G_loss: 6.2601 | D_loss: 0.5791
D_real: 0.5824 | D_fake: 0.5759
D(x): 0.592 | D(G(z)): 0.402
Val Loss: 0.3487


Epoch 54/100: 100%|██████████| 264/264 [01:18<00:00,  3.36it/s, G=6.2590, D=0.5788, D(x)=0.595, D(G(z))=0.402]



Epoch 54 Summary:
G_loss: 6.2590 | D_loss: 0.5788
D_real: 0.5799 | D_fake: 0.5777
D(x): 0.595 | D(G(z)): 0.402
Val Loss: 0.3506


Epoch 55/100: 100%|██████████| 264/264 [01:18<00:00,  3.35it/s, G=6.2598, D=0.5824, D(x)=0.597, D(G(z))=0.402]



Epoch 55 Summary:
G_loss: 6.2598 | D_loss: 0.5824
D_real: 0.5829 | D_fake: 0.5818
D(x): 0.597 | D(G(z)): 0.402
Val Loss: 0.3446


Epoch 56/100: 100%|██████████| 264/264 [01:18<00:00,  3.35it/s, G=6.2720, D=0.5767, D(x)=0.593, D(G(z))=0.407]



Epoch 56 Summary:
G_loss: 6.2720 | D_loss: 0.5767
D_real: 0.5782 | D_fake: 0.5751
D(x): 0.593 | D(G(z)): 0.407
Val Loss: 0.3276


Epoch 57/100: 100%|██████████| 264/264 [01:18<00:00,  3.36it/s, G=6.2741, D=0.5734, D(x)=0.596, D(G(z))=0.399]



Epoch 57 Summary:
G_loss: 6.2741 | D_loss: 0.5734
D_real: 0.5743 | D_fake: 0.5725
D(x): 0.596 | D(G(z)): 0.399
Val Loss: 0.3346


Epoch 58/100: 100%|██████████| 264/264 [01:18<00:00,  3.36it/s, G=6.2884, D=0.5744, D(x)=0.597, D(G(z))=0.401]



Epoch 58 Summary:
G_loss: 6.2884 | D_loss: 0.5744
D_real: 0.5778 | D_fake: 0.5709
D(x): 0.597 | D(G(z)): 0.401
Val Loss: 0.3304


Epoch 59/100: 100%|██████████| 264/264 [01:18<00:00,  3.35it/s, G=6.3129, D=0.5729, D(x)=0.597, D(G(z))=0.396]



Epoch 59 Summary:
G_loss: 6.3129 | D_loss: 0.5729
D_real: 0.5770 | D_fake: 0.5688
D(x): 0.597 | D(G(z)): 0.396
Val Loss: 0.3260


Epoch 60/100: 100%|██████████| 264/264 [01:18<00:00,  3.35it/s, G=6.3094, D=0.5665, D(x)=0.602, D(G(z))=0.395]



Epoch 60 Summary:
G_loss: 6.3094 | D_loss: 0.5665
D_real: 0.5678 | D_fake: 0.5653
D(x): 0.602 | D(G(z)): 0.395
Val Loss: 0.3391


Epoch 61/100: 100%|██████████| 264/264 [01:19<00:00,  3.33it/s, G=6.3023, D=0.5622, D(x)=0.603, D(G(z))=0.390]



Epoch 61 Summary:
G_loss: 6.3023 | D_loss: 0.5622
D_real: 0.5632 | D_fake: 0.5613
D(x): 0.603 | D(G(z)): 0.390
Val Loss: 0.3394


Epoch 62/100: 100%|██████████| 264/264 [01:19<00:00,  3.34it/s, G=6.2990, D=0.5640, D(x)=0.605, D(G(z))=0.394]



Epoch 62 Summary:
G_loss: 6.2990 | D_loss: 0.5640
D_real: 0.5622 | D_fake: 0.5657
D(x): 0.605 | D(G(z)): 0.394
Val Loss: 0.3376


Epoch 63/100: 100%|██████████| 264/264 [01:19<00:00,  3.34it/s, G=6.3059, D=0.5741, D(x)=0.603, D(G(z))=0.391]



Epoch 63 Summary:
G_loss: 6.3059 | D_loss: 0.5741
D_real: 0.5755 | D_fake: 0.5726
D(x): 0.603 | D(G(z)): 0.391
Val Loss: 0.3386


Epoch 64/100: 100%|██████████| 264/264 [01:19<00:00,  3.34it/s, G=6.2897, D=0.5486, D(x)=0.610, D(G(z))=0.389]



Epoch 64 Summary:
G_loss: 6.2897 | D_loss: 0.5486
D_real: 0.5503 | D_fake: 0.5469
D(x): 0.610 | D(G(z)): 0.389
Val Loss: 0.3336


Epoch 65/100: 100%|██████████| 264/264 [01:19<00:00,  3.33it/s, G=6.3116, D=0.5567, D(x)=0.607, D(G(z))=0.388]



Epoch 65 Summary:
G_loss: 6.3116 | D_loss: 0.5567
D_real: 0.5587 | D_fake: 0.5546
D(x): 0.607 | D(G(z)): 0.388
Val Loss: 0.3471


Epoch 66/100: 100%|██████████| 264/264 [01:19<00:00,  3.34it/s, G=6.2945, D=0.5630, D(x)=0.607, D(G(z))=0.389]



Epoch 66 Summary:
G_loss: 6.2945 | D_loss: 0.5630
D_real: 0.5637 | D_fake: 0.5623
D(x): 0.607 | D(G(z)): 0.389
Val Loss: 0.3344


Epoch 67/100: 100%|██████████| 264/264 [01:19<00:00,  3.34it/s, G=6.3081, D=0.5439, D(x)=0.614, D(G(z))=0.382]



Epoch 67 Summary:
G_loss: 6.3081 | D_loss: 0.5439
D_real: 0.5462 | D_fake: 0.5416
D(x): 0.614 | D(G(z)): 0.382
Val Loss: 0.3266


Epoch 68/100: 100%|██████████| 264/264 [01:19<00:00,  3.34it/s, G=6.3353, D=0.5454, D(x)=0.616, D(G(z))=0.381]



Epoch 68 Summary:
G_loss: 6.3353 | D_loss: 0.5454
D_real: 0.5452 | D_fake: 0.5456
D(x): 0.616 | D(G(z)): 0.381
Val Loss: 0.3360


Epoch 69/100: 100%|██████████| 264/264 [01:19<00:00,  3.34it/s, G=6.3204, D=0.5515, D(x)=0.614, D(G(z))=0.383]



Epoch 69 Summary:
G_loss: 6.3204 | D_loss: 0.5515
D_real: 0.5569 | D_fake: 0.5461
D(x): 0.614 | D(G(z)): 0.383
Val Loss: 0.3290


Epoch 70/100: 100%|██████████| 264/264 [01:19<00:00,  3.33it/s, G=6.3055, D=0.5338, D(x)=0.621, D(G(z))=0.376]



Epoch 70 Summary:
G_loss: 6.3055 | D_loss: 0.5338
D_real: 0.5361 | D_fake: 0.5315
D(x): 0.621 | D(G(z)): 0.376
Val Loss: 0.3551


Epoch 71/100: 100%|██████████| 264/264 [01:18<00:00,  3.34it/s, G=6.3138, D=0.5335, D(x)=0.620, D(G(z))=0.379]



Epoch 71 Summary:
G_loss: 6.3138 | D_loss: 0.5335
D_real: 0.5355 | D_fake: 0.5316
D(x): 0.620 | D(G(z)): 0.379
Val Loss: 0.3362


Epoch 72/100: 100%|██████████| 264/264 [01:18<00:00,  3.34it/s, G=6.3605, D=0.5373, D(x)=0.623, D(G(z))=0.375]



Epoch 72 Summary:
G_loss: 6.3605 | D_loss: 0.5373
D_real: 0.5402 | D_fake: 0.5343
D(x): 0.623 | D(G(z)): 0.375
Val Loss: 0.3309


Epoch 73/100: 100%|██████████| 264/264 [01:19<00:00,  3.33it/s, G=6.3396, D=0.5246, D(x)=0.627, D(G(z))=0.372]



Epoch 73 Summary:
G_loss: 6.3396 | D_loss: 0.5246
D_real: 0.5246 | D_fake: 0.5246
D(x): 0.627 | D(G(z)): 0.372
Val Loss: 0.3416


Epoch 74/100: 100%|██████████| 264/264 [01:19<00:00,  3.34it/s, G=6.3509, D=0.5171, D(x)=0.629, D(G(z))=0.369]



Epoch 74 Summary:
G_loss: 6.3509 | D_loss: 0.5171
D_real: 0.5190 | D_fake: 0.5153
D(x): 0.629 | D(G(z)): 0.369
Val Loss: 0.3317


Epoch 75/100: 100%|██████████| 264/264 [01:19<00:00,  3.34it/s, G=6.3438, D=0.5300, D(x)=0.626, D(G(z))=0.371]



Epoch 75 Summary:
G_loss: 6.3438 | D_loss: 0.5300
D_real: 0.5315 | D_fake: 0.5285
D(x): 0.626 | D(G(z)): 0.371
Val Loss: 0.3329


Epoch 76/100: 100%|██████████| 264/264 [01:19<00:00,  3.34it/s, G=6.3304, D=0.5245, D(x)=0.628, D(G(z))=0.375]



Epoch 76 Summary:
G_loss: 6.3304 | D_loss: 0.5245
D_real: 0.5251 | D_fake: 0.5240
D(x): 0.628 | D(G(z)): 0.375
Val Loss: 0.3275


Epoch 77/100: 100%|██████████| 264/264 [01:19<00:00,  3.34it/s, G=6.3294, D=0.5201, D(x)=0.630, D(G(z))=0.368]



Epoch 77 Summary:
G_loss: 6.3294 | D_loss: 0.5201
D_real: 0.5219 | D_fake: 0.5183
D(x): 0.630 | D(G(z)): 0.368
Val Loss: 0.3246


Epoch 78/100: 100%|██████████| 264/264 [01:19<00:00,  3.34it/s, G=6.3572, D=0.5119, D(x)=0.634, D(G(z))=0.365]



Epoch 78 Summary:
G_loss: 6.3572 | D_loss: 0.5119
D_real: 0.5128 | D_fake: 0.5110
D(x): 0.634 | D(G(z)): 0.365
Val Loss: 0.3414


Epoch 79/100: 100%|██████████| 264/264 [01:19<00:00,  3.33it/s, G=6.3716, D=0.5184, D(x)=0.631, D(G(z))=0.361]



Epoch 79 Summary:
G_loss: 6.3716 | D_loss: 0.5184
D_real: 0.5204 | D_fake: 0.5165
D(x): 0.631 | D(G(z)): 0.361
Val Loss: 0.3275


Epoch 80/100: 100%|██████████| 264/264 [01:19<00:00,  3.33it/s, G=6.3710, D=0.5062, D(x)=0.639, D(G(z))=0.358]



Epoch 80 Summary:
G_loss: 6.3710 | D_loss: 0.5062
D_real: 0.5076 | D_fake: 0.5048
D(x): 0.639 | D(G(z)): 0.358
Val Loss: 0.3370


Epoch 81/100: 100%|██████████| 264/264 [01:19<00:00,  3.34it/s, G=6.3573, D=0.5084, D(x)=0.639, D(G(z))=0.360]



Epoch 81 Summary:
G_loss: 6.3573 | D_loss: 0.5084
D_real: 0.5074 | D_fake: 0.5093
D(x): 0.639 | D(G(z)): 0.360
Val Loss: 0.3329


Epoch 82/100: 100%|██████████| 264/264 [01:19<00:00,  3.33it/s, G=6.3457, D=0.4948, D(x)=0.645, D(G(z))=0.359]



Epoch 82 Summary:
G_loss: 6.3457 | D_loss: 0.4948
D_real: 0.4969 | D_fake: 0.4927
D(x): 0.645 | D(G(z)): 0.359
Val Loss: 0.3383


Epoch 83/100: 100%|██████████| 264/264 [01:18<00:00,  3.34it/s, G=6.3360, D=0.5042, D(x)=0.639, D(G(z))=0.356]



Epoch 83 Summary:
G_loss: 6.3360 | D_loss: 0.5042
D_real: 0.5068 | D_fake: 0.5016
D(x): 0.639 | D(G(z)): 0.356
Val Loss: 0.3309


Epoch 84/100: 100%|██████████| 264/264 [01:18<00:00,  3.34it/s, G=6.3156, D=0.4979, D(x)=0.646, D(G(z))=0.359]



Epoch 84 Summary:
G_loss: 6.3156 | D_loss: 0.4979
D_real: 0.4963 | D_fake: 0.4995
D(x): 0.646 | D(G(z)): 0.359
Val Loss: 0.3317


Epoch 85/100: 100%|██████████| 264/264 [01:18<00:00,  3.35it/s, G=6.3574, D=0.4889, D(x)=0.648, D(G(z))=0.349]



Epoch 85 Summary:
G_loss: 6.3574 | D_loss: 0.4889
D_real: 0.4918 | D_fake: 0.4859
D(x): 0.648 | D(G(z)): 0.349
Val Loss: 0.3474


Epoch 86/100: 100%|██████████| 264/264 [01:18<00:00,  3.35it/s, G=6.3008, D=0.4918, D(x)=0.649, D(G(z))=0.352]



Epoch 86 Summary:
G_loss: 6.3008 | D_loss: 0.4918
D_real: 0.4906 | D_fake: 0.4929
D(x): 0.649 | D(G(z)): 0.352
Val Loss: 0.3344


Epoch 87/100: 100%|██████████| 264/264 [01:19<00:00,  3.33it/s, G=6.3325, D=0.4835, D(x)=0.654, D(G(z))=0.343]



Epoch 87 Summary:
G_loss: 6.3325 | D_loss: 0.4835
D_real: 0.4855 | D_fake: 0.4816
D(x): 0.654 | D(G(z)): 0.343
Val Loss: 0.3324


Epoch 88/100: 100%|██████████| 264/264 [01:19<00:00,  3.33it/s, G=6.3266, D=0.4829, D(x)=0.654, D(G(z))=0.343]



Epoch 88 Summary:
G_loss: 6.3266 | D_loss: 0.4829
D_real: 0.4845 | D_fake: 0.4813
D(x): 0.654 | D(G(z)): 0.343
Val Loss: 0.3374


Epoch 89/100: 100%|██████████| 264/264 [01:19<00:00,  3.34it/s, G=6.3100, D=0.4777, D(x)=0.657, D(G(z))=0.343]



Epoch 89 Summary:
G_loss: 6.3100 | D_loss: 0.4777
D_real: 0.4780 | D_fake: 0.4774
D(x): 0.657 | D(G(z)): 0.343
Val Loss: 0.3369


Epoch 90/100: 100%|██████████| 264/264 [01:19<00:00,  3.33it/s, G=6.2978, D=0.4819, D(x)=0.654, D(G(z))=0.346]



Epoch 90 Summary:
G_loss: 6.2978 | D_loss: 0.4819
D_real: 0.4824 | D_fake: 0.4814
D(x): 0.654 | D(G(z)): 0.346
Val Loss: 0.3337


Epoch 91/100: 100%|██████████| 264/264 [01:18<00:00,  3.34it/s, G=6.2838, D=0.4736, D(x)=0.658, D(G(z))=0.346]



Epoch 91 Summary:
G_loss: 6.2838 | D_loss: 0.4736
D_real: 0.4752 | D_fake: 0.4720
D(x): 0.658 | D(G(z)): 0.346
Val Loss: 0.3321


Epoch 92/100: 100%|██████████| 264/264 [01:19<00:00,  3.34it/s, G=6.2943, D=0.4653, D(x)=0.664, D(G(z))=0.334]



Epoch 92 Summary:
G_loss: 6.2943 | D_loss: 0.4653
D_real: 0.4647 | D_fake: 0.4659
D(x): 0.664 | D(G(z)): 0.334
Val Loss: 0.3347


Epoch 93/100: 100%|██████████| 264/264 [01:19<00:00,  3.34it/s, G=6.2749, D=0.4704, D(x)=0.661, D(G(z))=0.336]



Epoch 93 Summary:
G_loss: 6.2749 | D_loss: 0.4704
D_real: 0.4710 | D_fake: 0.4698
D(x): 0.661 | D(G(z)): 0.336
Val Loss: 0.3286


Epoch 94/100: 100%|██████████| 264/264 [01:19<00:00,  3.34it/s, G=6.2509, D=0.4661, D(x)=0.663, D(G(z))=0.342]



Epoch 94 Summary:
G_loss: 6.2509 | D_loss: 0.4661
D_real: 0.4666 | D_fake: 0.4656
D(x): 0.663 | D(G(z)): 0.342
Val Loss: 0.3334


Epoch 95/100: 100%|██████████| 264/264 [01:19<00:00,  3.32it/s, G=6.3143, D=0.4753, D(x)=0.656, D(G(z))=0.341]



Epoch 95 Summary:
G_loss: 6.3143 | D_loss: 0.4753
D_real: 0.4755 | D_fake: 0.4752
D(x): 0.656 | D(G(z)): 0.341
Val Loss: 0.3326


Epoch 96/100: 100%|██████████| 264/264 [01:19<00:00,  3.33it/s, G=6.2565, D=0.4695, D(x)=0.660, D(G(z))=0.340]



Epoch 96 Summary:
G_loss: 6.2565 | D_loss: 0.4695
D_real: 0.4704 | D_fake: 0.4686
D(x): 0.660 | D(G(z)): 0.340
Val Loss: 0.3332


Epoch 97/100: 100%|██████████| 264/264 [01:19<00:00,  3.34it/s, G=6.2411, D=0.4647, D(x)=0.663, D(G(z))=0.337]



Epoch 97 Summary:
G_loss: 6.2411 | D_loss: 0.4647
D_real: 0.4657 | D_fake: 0.4638
D(x): 0.663 | D(G(z)): 0.337
Val Loss: 0.3344


Epoch 98/100: 100%|██████████| 264/264 [01:19<00:00,  3.33it/s, G=6.2421, D=0.4583, D(x)=0.670, D(G(z))=0.339]



Epoch 98 Summary:
G_loss: 6.2421 | D_loss: 0.4583
D_real: 0.4575 | D_fake: 0.4592
D(x): 0.670 | D(G(z)): 0.339
Val Loss: 0.3346


Epoch 99/100: 100%|██████████| 264/264 [01:18<00:00,  3.35it/s, G=6.2323, D=0.4565, D(x)=0.671, D(G(z))=0.333]



Epoch 99 Summary:
G_loss: 6.2323 | D_loss: 0.4565
D_real: 0.4549 | D_fake: 0.4582
D(x): 0.671 | D(G(z)): 0.333
Val Loss: 0.3282


Epoch 100/100: 100%|██████████| 264/264 [01:19<00:00,  3.33it/s, G=6.2235, D=0.4590, D(x)=0.666, D(G(z))=0.337]



Epoch 100 Summary:
G_loss: 6.2235 | D_loss: 0.4590
D_real: 0.4599 | D_fake: 0.4582
D(x): 0.666 | D(G(z)): 0.337
Val Loss: 0.3323


In [11]:
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [12]:
len(test_loader)

264

In [17]:
import torch
from torch.utils.data import DataLoader, random_split
import torchvision.utils as vutils
import os
from tqdm import tqdm
from matplotlib import pyplot as plt
from PIL import Image
import numpy as np

def inference(checkpoint_path, input_dir, target_dir, save_dir='predictions', num_samples=10):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Загрузка модели
    netG = UNet().to(device)
    netG.load_state_dict(torch.load(checkpoint_path, map_location=device))
    netG.eval()


    os.makedirs(save_dir, exist_ok=True)
    mask_dir = os.path.join(save_dir, 'predicted_masks')
    os.makedirs(mask_dir, exist_ok=True)

    with torch.no_grad():
        for idx, (input_img, target_img) in tqdm(enumerate(test_loader), total=len(test_loader)):
            input_img = input_img.to(device)
            fake_mask = netG(input_img)

            # Сопоставим имя файла с оригинальным индексом
            real_idx = train_size + idx
            base_name = os.path.splitext(dataset.filenames[real_idx])[0]  # например: "img_012"

            # Обработка изображений
            fake_mask_img = fake_mask.squeeze(0).squeeze(0).cpu()       # [1, 1, H, W] → [H, W]
            input_vis = input_img.squeeze(0).cpu()                      # [1, C, H, W] → [C, H, W]
            target_vis = target_img.squeeze(0).squeeze(0).cpu()         # [1, 1, H, W] → [H, W]

            # Сохраняем отдельную предсказанную маску
            plt.imsave(os.path.join(mask_dir, f'{base_name}_pred.png'), fake_mask_img, cmap='gray')

            # Визуализация
            fig, axs = plt.subplots(1, 3, figsize=(12, 4))

            if input_vis.shape[0] == 3:  # RGB
                axs[0].imshow(input_vis.permute(1, 2, 0) * 0.5 + 0.5)
            else:  # Grayscale
                axs[0].imshow(input_vis.squeeze(), cmap='gray')
            axs[0].set_title('Input')

            axs[1].imshow(target_vis, cmap='gray')
            axs[1].set_title('Ground Truth')

            axs[2].imshow(fake_mask_img, cmap='gray')
            axs[2].set_title('Predicted Mask')

            for ax in axs:
                ax.axis('off')

            plt.tight_layout()
            plt.savefig(f'{save_dir}/sample_{idx:03d}.png')
            plt.close()

            if idx + 1 >= num_samples:
                break


if __name__ == '__main__':
    inference(
        checkpoint_path='generator_final.pth',
        input_dir='/kaggle/input/dataset1/sliced images/input',
        target_dir='/kaggle/input/dataset1/sliced images/target',
        save_dir='predictions',
        num_samples=264
    )

  netG.load_state_dict(torch.load(checkpoint_path, map_location=device))
100%|█████████▉| 263/264 [01:41<00:00,  2.58it/s]


In [23]:
import torch
from torch.utils.data import DataLoader, random_split
import os
from tqdm import tqdm
from matplotlib import pyplot as plt
from PIL import Image
import numpy as np
from skimage.metrics import structural_similarity as ssim_func
from math import log10
import csv

def compute_metrics(pred, target):
    pred = (pred > 0).astype(np.uint8)
    target = (target > 0).astype(np.uint8)

    intersection = np.logical_and(pred, target).sum()
    union = np.logical_or(pred, target).sum()

    dice = 2 * intersection / (pred.sum() + target.sum() + 1e-8)
    iou = intersection / (union + 1e-8)
    mae = np.mean(np.abs(pred - target))
    rmse = np.sqrt(np.mean((pred - target) ** 2))
    ssim = ssim_func(target, pred, data_range=1)

    mse = np.mean((pred - target) ** 2)
    psnr = 20 * log10(1.0) - 10 * log10(mse + 1e-8)  # Assuming pixel values in [0, 1]

    return dice, iou, mae, rmse, ssim, psnr

def inference_with_metrics(checkpoint_path, input_dir, target_dir, save_dir='predictions'):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Загрузка модели
    netG = UNet().to(device)
    netG.load_state_dict(torch.load(checkpoint_path, map_location=device))
    netG.eval()

    os.makedirs(save_dir, exist_ok=True)

    total_dice = total_iou = total_mae = total_rmse = total_ssim = total_psnr = 0

    with torch.no_grad():
        for idx, (input_img, target_img) in tqdm(enumerate(test_loader), total=len(test_loader)):
            input_img = input_img.to(device)
            fake_mask = netG(input_img)

            fake_mask_np = fake_mask.squeeze().cpu().numpy()
            target_np = target_img.squeeze().cpu().numpy()

            # Бинаризация
            fake_mask_np_bin = (fake_mask_np > 0).astype(np.uint8)
            target_np_bin = (target_np > 0).astype(np.uint8)

            # Метрики
            dice, iou, mae, rmse, ssim, psnr = compute_metrics(fake_mask_np_bin, target_np_bin)
            total_dice += dice
            total_iou += iou
            total_mae += mae
            total_rmse += rmse
            total_ssim += ssim
            total_psnr += psnr

    N = len(test_loader)
    avg_dice = total_dice / N
    avg_iou = total_iou / N
    avg_mae = total_mae / N
    avg_rmse = total_rmse / N
    avg_ssim = total_ssim / N
    avg_psnr = total_psnr / N

    print(f"\n📊 Metrics on Test Set:")
    print(f"Dice:  {avg_dice:.4f}")
    print(f"IoU:   {avg_iou:.4f}")
    print(f"MAE:   {avg_mae:.4f}")
    print(f"RMSE:  {avg_rmse:.4f}")
    print(f"SSIM:  {avg_ssim:.4f}")
    print(f"PSNR:  {avg_psnr:.4f}")

    # Сохранение в CSV
    csv_path = os.path.join(save_dir, 'metrics.csv')
    with open(csv_path, mode='w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['Metric', 'Value'])
        writer.writerow(['Dice', avg_dice])
        writer.writerow(['IoU', avg_iou])
        writer.writerow(['MAE', avg_mae])
        writer.writerow(['RMSE', avg_rmse])
        writer.writerow(['SSIM', avg_ssim])
        writer.writerow(['PSNR', avg_psnr])

if __name__ == '__main__':
    inference_with_metrics(
        checkpoint_path='generator_final.pth',
        input_dir='/kaggle/input/dataset1/sliced images/input',
        target_dir='/kaggle/input/dataset1/sliced images/target',
        save_dir='predictions'
    )

  netG.load_state_dict(torch.load(checkpoint_path, map_location=device))
100%|██████████| 264/264 [00:18<00:00, 13.91it/s]


📊 Metrics on Test Set:
Dice:  0.8333
IoU:   0.7273
MAE:   21.0411
RMSE:  0.3929
SSIM:  0.5130
PSNR:  8.3518





In [24]:
import shutil

shutil.make_archive('/kaggle/working/output', 'zip', '/kaggle/working')

'/kaggle/working/output.zip'