In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.utils import save_image
import os
import random
from PIL import Image
import numpy as np
from tqdm import tqdm
import json
from matplotlib import pyplot as plt
from torch.cuda.amp import autocast, GradScaler

In [2]:
class SketchDataset(Dataset):
    """Unpaired (bad sketch, good sketch) dataset for CycleGAN-style training.

    Item ``index`` is the ``index``-th bad sketch (sorted by path) paired with a
    good sketch chosen uniformly at random, so A/B pairs are *not* aligned.
    Images are converted to single-channel grayscale, resized to 256x256 and
    normalized from [0, 1] to [-1, 1].

    Args:
        bad_sketch_dir: directory of domain-A ("bad") sketch images.
        good_sketch_dir: directory of domain-B ("good") sketch images.
        augment: if True, additionally apply a random horizontal flip and a
            random affine transform (rotation/shear/scale, white fill) to both
            images. Defaults to False, which matches the previous behaviour.

    Raises:
        ValueError: if there are bad sketches but ``good_sketch_dir`` contains
            no files to sample partners from.
    """

    def __init__(self, bad_sketch_dir, good_sketch_dir, augment=False):
        self.bad_sketch_paths = self._list_files(bad_sketch_dir)
        self.good_sketch_paths = self._list_files(good_sketch_dir)
        if self.bad_sketch_paths and not self.good_sketch_paths:
            # random.choice would otherwise fail later, inside a DataLoader worker.
            raise ValueError(f"No good sketches found in {good_sketch_dir!r}")
        self.augment = augment

        self.transform = transforms.Compose(
            [
                transforms.Grayscale(),
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,)),
            ]
        )

        # Resize is required here too; without it, images of different sizes
        # could not be collated into one batch.
        self.transform_augment = transforms.Compose(
            [
                transforms.Grayscale(num_output_channels=1),
                transforms.Resize((256, 256)),
                transforms.RandomHorizontalFlip(),
                transforms.RandomAffine(10, shear=10, scale=(0.8, 1.2), fill=255),
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,)),
            ]
        )

    @staticmethod
    def _list_files(directory):
        """Return the sorted full paths of regular, non-hidden files in ``directory``."""
        paths = []
        for name in os.listdir(directory):
            path = os.path.join(directory, name)
            # Skip hidden entries (e.g. .DS_Store) and sub-directories, which PIL cannot open.
            if not name.startswith(".") and os.path.isfile(path):
                paths.append(path)
        return sorted(paths)

    @staticmethod
    def _open_grayscale(path):
        """Load ``path`` as an in-memory 8-bit grayscale ("L") image and close the file."""
        # Image.open is lazy and keeps the file handle open until the image object
        # is garbage-collected; converting inside ``with`` decodes the pixels and
        # releases the handle immediately.
        with Image.open(path) as img:
            return img.convert("L")

    def __getitem__(self, index):
        """Return ``{"A": bad_sketch_tensor, "B": random_good_sketch_tensor}``."""
        bad_path = self.bad_sketch_paths[index % len(self.bad_sketch_paths)]
        good_path = random.choice(self.good_sketch_paths)
        transform = self.transform_augment if self.augment else self.transform
        bad_sketch = transform(self._open_grayscale(bad_path))
        good_sketch = transform(self._open_grayscale(good_path))
        return {"A": bad_sketch, "B": good_sketch}

    def __len__(self):
        """Number of bad sketches; good sketches are sampled, not indexed."""
        return len(self.bad_sketch_paths)

In [3]:
# Residual Block
class ResidualBlock(nn.Module):
    """Residual unit computing ``x + F(x)``, where F is conv3x3-IN-ReLU-conv3x3-IN.

    Both convolutions keep the channel count, and with padding=1 they also keep
    the spatial size, so the output has exactly the shape of the input.
    """

    def __init__(self, channels):
        super().__init__()

        def conv_norm():
            # 3x3 same-padding convolution followed by instance normalization.
            return [
                nn.Conv2d(channels, channels, 3, padding=1),
                nn.InstanceNorm2d(channels),
            ]

        # The attribute name ``block`` and this layer order define the
        # state_dict keys (block.0.*, block.3.*) used by saved checkpoints.
        self.block = nn.Sequential(*conv_norm(), nn.ReLU(inplace=True), *conv_norm())

    def forward(self, x):
        skip = x
        return skip + self.block(x)


class Generator(nn.Module):
    """Encoder / residual-bottleneck / decoder image-to-image generator.

    Layout:
      7x7 conv
      -> two stride-2 3x3 downsampling convs
      -> ``n_residual_blocks`` ResidualBlocks
      -> two stride-2 3x3 transposed-conv upsampling layers
      -> 7x7 conv
      -> tanh (output in [-1, 1]).

    Every intermediate layer is conv -> InstanceNorm -> ReLU. The output has the
    same height/width as the input when both are divisible by 4.

    Args:
        in_channels: channels of the input image (default 1, grayscale).
        out_channels: channels of the generated image (default 1, grayscale).
        base_channels: width after the first conv; doubled by each downsampling layer.
        n_residual_blocks: number of residual blocks at the bottleneck.

    The defaults reproduce the original fixed architecture exactly (same
    submodule names and parameter creation order), so previously saved
    ``state_dict`` checkpoints still load.
    """

    def __init__(self, in_channels=1, out_channels=1, base_channels=64, n_residual_blocks=9):
        super().__init__()
        c1, c2, c3 = base_channels, base_channels * 2, base_channels * 4

        # Initial convolution block (7x7, spatial size preserved)
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, c1, kernel_size=7, stride=1, padding=3),
            nn.InstanceNorm2d(c1),
            nn.ReLU(inplace=True),
        )

        # Downsampling blocks: each halves H and W and doubles the channels
        self.down_blocks = nn.Sequential(self._make_layer(c1, c2), self._make_layer(c2, c3))

        # Residual blocks at the bottleneck resolution
        self.res_blocks = nn.Sequential(*[ResidualBlock(c3) for _ in range(n_residual_blocks)])

        # Upsampling blocks: each doubles H and W and halves the channels
        self.up_blocks = nn.Sequential(
            self._make_layer(c3, c2, upsample=True),
            self._make_layer(c2, c1, upsample=True),
        )

        # Output convolution; tanh matches the [-1, 1] input normalization
        self.conv2 = nn.Sequential(
            nn.Conv2d(c1, out_channels, kernel_size=7, stride=1, padding=3), nn.Tanh()
        )

    def _make_layer(self, in_channels, out_channels, upsample=False):
        """Build a stride-2 conv (or transposed conv if ``upsample``) + InstanceNorm + ReLU."""
        if upsample:
            # output_padding=1 makes the transposed conv exactly double H and W.
            return nn.Sequential(
                nn.ConvTranspose2d(
                    in_channels, out_channels, 3, stride=2, padding=1, output_padding=1
                ),
                nn.InstanceNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, stride=2, padding=1),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Translate a batch of images; returns values in [-1, 1]."""
        x = self.conv1(x)
        x = self.down_blocks(x)
        x = self.res_blocks(x)
        x = self.up_blocks(x)
        return self.conv2(x)


# Discriminator Network
class Discriminator(nn.Module):
    """Convolutional patch discriminator for single-channel images.

    Returns a map of unbounded real-valued scores (no final activation). For a
    256x256 input the output has shape (N, 1, 30, 30).
    """

    def __init__(self):
        super().__init__()
        # First stage: no normalization layer.
        layers = [nn.Conv2d(1, 64, 4, stride=2, padding=1), nn.LeakyReLU(0.2, inplace=True)]
        # (in_channels, out_channels, stride) for the conv -> InstanceNorm -> LeakyReLU stages.
        for c_in, c_out, step in ((64, 128, 2), (128, 256, 2), (256, 512, 1)):
            layers.extend(
                [
                    nn.Conv2d(c_in, c_out, 4, stride=step, padding=1),
                    nn.InstanceNorm2d(c_out),
                    nn.LeakyReLU(0.2, inplace=True),
                ]
            )
        # Projection to a single score channel.
        layers.append(nn.Conv2d(512, 1, 4, padding=1))
        # The attribute name ``model`` and these layer indices define the state_dict keys.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.model(x)
        return scores

In [4]:
# Training class for each sketch category
class SketchEnhancer:
    """CycleGAN trainer that maps bad sketches (domain A) to good sketches (domain B).

    Holds two generators (``G_AB``: A->B, ``G_BA``: B->A), two discriminators
    (``D_A``, ``D_B``), one Adam optimizer per pair, and AMP grad scalers.
    Checkpoints, the JSON training history and sample grids are all written
    under ``<log_dir>/<class_name>``.

    Args:
        class_name: sketch category; also the sub-directory name inside ``log_dir``.
        train_set, val_set, test_set: datasets yielding ``{"A": tensor, "B": tensor}``.
        batch_size: mini-batch size for all three loaders.
        device: torch device string; mixed precision is enabled only for CUDA.
        log_dir: root directory for checkpoints, history and sample images.
    """

    def __init__(
        self,
        class_name,
        train_set,
        val_set,
        test_set,
        batch_size=64,
        device="cuda",
        log_dir="logs",
    ):
        self.class_name = class_name
        self.device = device
        # torch.cuda.amp autocast / GradScaler only take effect on CUDA; making this
        # explicit avoids "CUDA not available" warnings when device is "cpu".
        self.use_amp = torch.device(device).type == "cuda"
        self.log_dir = os.path.join(log_dir, class_name)
        os.makedirs(self.log_dir, exist_ok=True)

        # Generators: G_AB maps A -> B, G_BA maps B -> A.
        self.G_AB = Generator().to(device)
        self.G_BA = Generator().to(device)
        # Discriminators: D_A scores domain-A images, D_B scores domain-B images.
        self.D_A = Discriminator().to(device)
        self.D_B = Discriminator().to(device)

        # One optimizer for both generators, one for both discriminators.
        self.g_optimizer = optim.Adam(
            list(self.G_AB.parameters()) + list(self.G_BA.parameters()),
            lr=0.0002,
            betas=(0.5, 0.999),
        )
        self.d_optimizer = optim.Adam(
            list(self.D_A.parameters()) + list(self.D_B.parameters()),
            lr=0.0002,
            betas=(0.5, 0.999),
        )

        # Least-squares adversarial loss; L1 for cycle-consistency and identity terms.
        self.criterion_GAN = nn.MSELoss()
        self.criterion_cycle = nn.L1Loss()
        self.criterion_identity = nn.L1Loss()

        self.train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
        # Evaluation loaders are not shuffled: shuffling is unnecessary for an
        # averaged loss and only adds run-to-run variation.
        self.val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)
        self.test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

        self.g_scaler = GradScaler(enabled=self.use_amp)
        self.d_scaler = GradScaler(enabled=self.use_amp)

        # Training history (persisted as JSON by save_history)
        self.history = {
            "train_losses": [],
            "val_losses": [],
            "best_val_loss": float("inf"),
        }

    def _networks(self):
        """Map checkpoint base names to the four networks (fixed order)."""
        return {"G_AB": self.G_AB, "G_BA": self.G_BA, "D_A": self.D_A, "D_B": self.D_B}

    def _set_mode(self, training):
        """Put all four networks in train (``True``) or eval (``False``) mode."""
        for net in self._networks().values():
            net.train(training)

    def _checkpoint_path(self, name, best):
        """Path of the checkpoint for network ``name`` (``best_`` prefix if ``best``)."""
        prefix = "best_" if best else ""
        return os.path.join(self.log_dir, f"{prefix}{name}.pth")

    def compute_losses(self, batch, use_amp=False):
        """Run one CycleGAN forward pass and compute generator and discriminator losses.

        Args:
            batch: dict with image tensors under ``"A"`` (bad) and ``"B"`` (good).
            use_amp: run the forward pass under ``autocast``.

        Returns:
            dict with ``loss_G`` (adversarial + 10x cycle + 5x identity),
            ``loss_D`` (LSGAN real/fake loss for D_A plus D_B, each halved),
            ``loss_cycle``, ``loss_identity`` and ``samples`` = (real_A, fake_B, rec_A).
        """
        real_A = batch["A"].to(self.device)
        real_B = batch["B"].to(self.device)

        with autocast(enabled=use_amp):
            # Generate fake samples
            fake_B = self.G_AB(real_A)
            fake_A = self.G_BA(real_B)

            # Reconstruct samples (cycle A->B->A and B->A->B)
            rec_A = self.G_BA(fake_B)
            rec_B = self.G_AB(fake_A)

            # Identity mapping: a generator fed its own target domain should change little
            id_A = self.G_BA(real_A)
            id_B = self.G_AB(real_B)

            # Identity losses
            loss_id_A = self.criterion_identity(id_A, real_A) * 5.0
            loss_id_B = self.criterion_identity(id_B, real_B) * 5.0

            # GAN losses: generators try to make D output 1 on fakes
            pred_fake_B = self.D_B(fake_B)
            loss_GAN_AB = self.criterion_GAN(pred_fake_B, torch.ones_like(pred_fake_B))
            pred_fake_A = self.D_A(fake_A)
            loss_GAN_BA = self.criterion_GAN(pred_fake_A, torch.ones_like(pred_fake_A))

            # Cycle losses
            loss_cycle_A = self.criterion_cycle(rec_A, real_A) * 10.0
            loss_cycle_B = self.criterion_cycle(rec_B, real_B) * 10.0

            # Total generator loss
            loss_G = (
                loss_GAN_AB
                + loss_GAN_BA
                + loss_cycle_A
                + loss_cycle_B
                + loss_id_A
                + loss_id_B
            )

            # Discriminator losses: real -> 1, fake -> 0
            pred_real_A = self.D_A(real_A)
            loss_D_real_A = self.criterion_GAN(
                pred_real_A, torch.ones_like(pred_real_A)
            )
            pred_real_B = self.D_B(real_B)
            loss_D_real_B = self.criterion_GAN(
                pred_real_B, torch.ones_like(pred_real_B)
            )

            # Fakes are detached so loss_D does not back-propagate into the generators.
            pred_fake_A = self.D_A(fake_A.detach())
            loss_D_fake_A = self.criterion_GAN(
                pred_fake_A, torch.zeros_like(pred_fake_A)
            )
            pred_fake_B = self.D_B(fake_B.detach())
            loss_D_fake_B = self.criterion_GAN(
                pred_fake_B, torch.zeros_like(pred_fake_B)
            )

            loss_D = (loss_D_real_A + loss_D_fake_A) * 0.5 + (
                loss_D_real_B + loss_D_fake_B
            ) * 0.5

        return {
            "loss_G": loss_G,
            "loss_D": loss_D,
            "loss_cycle": loss_cycle_A + loss_cycle_B,
            "loss_identity": loss_id_A + loss_id_B,
            "samples": (real_A, fake_B, rec_A),
        }

    def _evaluate(self, loader):
        """Mean of (loss_G + loss_D) per batch over ``loader``, without gradients.

        Leaves all networks in eval mode. Returns NaN for an empty loader.
        """
        self._set_mode(False)
        totals = []
        with torch.no_grad():
            for batch in loader:
                losses = self.compute_losses(batch, use_amp=self.use_amp)
                totals.append(losses["loss_G"].item() + losses["loss_D"].item())
        return float(np.mean(totals)) if totals else float("nan")

    def validate(self):
        """Mean (loss_G + loss_D) over the validation set."""
        return self._evaluate(self.val_loader)

    def train(self, num_epochs):
        """Train for ``num_epochs`` epochs, alternating generator and discriminator updates.

        After each epoch:
          - record the mean train and validation losses;
          - save ``best_*`` checkpoints when the validation loss improves;
          - save a sample grid every 5 epochs;
          - rewrite the JSON history.

        NOTE(review): the "loss" tracked is loss_G + loss_D. The two terms are
        adversarial, so this is a weak model-selection signal.
        """
        for epoch in range(num_epochs):
            # Training phase
            self._set_mode(True)

            train_losses = []
            last_samples = None
            batch_bar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}", leave=False)

            for batch in batch_bar:
                losses = self.compute_losses(batch, use_amp=self.use_amp)

                # Generator step. loss_G also back-propagates into D_A/D_B through
                # D(fake); those gradients are cleared by d_optimizer.zero_grad() below.
                self.g_optimizer.zero_grad()
                self.g_scaler.scale(losses["loss_G"]).backward()
                self.g_scaler.step(self.g_optimizer)
                self.g_scaler.update()

                # Discriminator step (loss_D only uses detached fakes).
                self.d_optimizer.zero_grad()
                self.d_scaler.scale(losses["loss_D"]).backward()
                self.d_scaler.step(self.d_optimizer)
                self.d_scaler.update()

                g_loss = losses["loss_G"].item()
                d_loss = losses["loss_D"].item()
                batch_bar.set_postfix(
                    {
                        "G_loss": f"{g_loss:.4f}",
                        "D_loss": f"{d_loss:.4f}",
                    }
                )
                train_losses.append(g_loss + d_loss)
                # Keep only detached tensors for the sample grid.
                last_samples = tuple(t.detach() for t in losses["samples"])

            # Validation phase
            val_loss = self.validate()

            # Update training history (plain floats keep the JSON clean)
            avg_train_loss = float(np.mean(train_losses)) if train_losses else float("nan")
            self.history["train_losses"].append(avg_train_loss)
            self.history["val_losses"].append(val_loss)

            print(
                f"Epoch[{epoch+1}] Train Loss: {avg_train_loss:.4f}   Validation Loss: {val_loss:.4f}"
            )

            # Save best model
            if val_loss < self.history["best_val_loss"]:
                self.history["best_val_loss"] = val_loss
                self.save_models(best=True)

            # Save sample images, labelled with the same 1-based epoch as the log line
            if (epoch + 1) % 5 == 0 and last_samples is not None:
                self.save_samples(*last_samples, epoch + 1)

            self.save_history()

    def save_samples(self, real_A, fake_B, rec_A, epoch):
        """Save a grid of real A / translated A->B / reconstructed A images into ``log_dir``.

        ``epoch`` is only used as a label in the filename.
        """
        # Cast to float32: under autocast the generator outputs may be float16.
        grid = torch.cat([t.detach().float() for t in (real_A, fake_B, rec_A)], 0)
        save_image(
            grid,
            os.path.join(self.log_dir, f"samples_{self.class_name}_epoch_{epoch}.png"),
            normalize=True,
        )

    def save_history(self):
        """Write ``self.history`` to ``<log_dir>/training_history.json``."""
        with open(os.path.join(self.log_dir, "training_history.json"), "w") as f:
            json.dump(self.history, f)

    def save_models(self, best=False):
        """Save all four state_dicts to ``<log_dir>/[best_]<name>.pth``."""
        for name, net in self._networks().items():
            torch.save(net.state_dict(), self._checkpoint_path(name, best))

    def test(self):
        """Mean (loss_G + loss_D) over the test set."""
        return self._evaluate(self.test_loader)

    def load_models(self, best=False):
        """Load the four state_dicts written by ``save_models(best=best)``.

        Reads from the same ``<log_dir>/[best_]<name>.pth`` paths that
        ``save_models`` writes (previously this looked for differently named
        files in the working directory). Tensors are mapped to ``self.device``.
        """
        for name, net in self._networks().items():
            state = torch.load(self._checkpoint_path(name, best), map_location=self.device)
            net.load_state_dict(state)


# Main training loop for all classes
def train_all_classes(
    base_dir,
    classes,
    batch_size=64,
    num_epochs=100,
    test_split=0.2,
    val_split=0.1,
    seed=None,
    device="cuda",
):
    """Train, checkpoint and evaluate one SketchEnhancer per sketch class.

    For each class, images are read from ``<base_dir>/bad_sketches/<class>`` and
    ``<base_dir>/good_sketches/<class>``. The dataset is first split into a test
    holdout of ``test_split``. The remainder is then split into a validation
    holdout of ``val_split`` and the training set.

    Args:
        base_dir: root directory containing ``bad_sketches`` and ``good_sketches``.
        classes: iterable of class (sub-directory) names to train.
        batch_size: mini-batch size passed to SketchEnhancer.
        num_epochs: number of training epochs per class.
        test_split: fraction of each class's dataset held out for testing.
        val_split: fraction of the remaining data held out for validation.
        seed: if given, makes the random splits reproducible (uses a dedicated
            generator, leaving the global RNG untouched).
        device: torch device string passed to SketchEnhancer.
    """
    split_kwargs = {} if seed is None else {"generator": torch.Generator().manual_seed(seed)}

    for class_name in classes:
        print(f"\nTraining class: {class_name}")

        # Setup data paths
        bad_sketch_dir = os.path.join(base_dir, "bad_sketches", class_name)
        good_sketch_dir = os.path.join(base_dir, "good_sketches", class_name)

        dataset = SketchDataset(bad_sketch_dir, good_sketch_dir)
        print(f"Number of samples: {len(dataset)}")

        train_set, test_set = _holdout_split(dataset, test_split, **split_kwargs)
        train_set, val_set = _holdout_split(train_set, val_split, **split_kwargs)
        print(f"Train set: {len(train_set)}")
        print(f"Validation set: {len(val_set)}")
        print(f"Test set: {len(test_set)}")

        # Initialize and train model
        enhancer = SketchEnhancer(
            class_name, train_set, val_set, test_set, batch_size, device=device
        )
        enhancer.train(num_epochs)
        # Final-epoch weights; best-validation weights were saved during training.
        enhancer.save_models()

        # Evaluated with the final-epoch weights (not the best-validation ones).
        test_loss = enhancer.test()
        print(f"Test loss: {test_loss:.4f}")


def _holdout_split(dataset, fraction, **split_kwargs):
    """Randomly split ``dataset`` into ``(remainder, holdout)``.

    The holdout gets ``round(len(dataset) * fraction)`` items and the remainder
    gets the rest, so the lengths always sum to ``len(dataset)``. Rounding both
    parts independently can be off by one (e.g. 25 * 0.9 -> 22 and 25 * 0.1 -> 2
    under round-half-to-even), which ``random_split`` rejects.
    """
    n_total = len(dataset)
    n_holdout = round(n_total * fraction)
    return random_split(dataset, [n_total - n_holdout, n_holdout], **split_kwargs)

In [None]:
base_dir = "./Datasets/"
bad_sketch_root = os.path.join(base_dir, "bad_sketches")
# Each sub-directory is one sketch class. Sort so the training order is
# deterministic (os.listdir order is arbitrary), and ignore stray files such
# as .DS_Store.
classes = sorted(
    name
    for name in os.listdir(bad_sketch_root)
    if os.path.isdir(os.path.join(bad_sketch_root, name))
)

train_all_classes(base_dir, classes, batch_size=16, num_epochs=50)


Training class: airplane
Number of samples: 3001
Train set: 2161
Validation set: 240
Test set: 600


                                                                                        

Epoch[1] Train Loss: 4.8358   Validation Loss: 1.5551


                                                                                        

Epoch[2] Train Loss: 1.8084   Validation Loss: 1.8282


                                                                                        

Epoch[3] Train Loss: 1.7786   Validation Loss: 2.6143


                                                                                        

Epoch[4] Train Loss: 1.8100   Validation Loss: 1.7792


                                                                                        

Epoch[5] Train Loss: 1.7594   Validation Loss: 1.9280


                                                                                        

Epoch[6] Train Loss: 1.8107   Validation Loss: 2.2160


                                                                                        

Epoch[7] Train Loss: 1.8760   Validation Loss: 1.9720


                                                                                        

Epoch[8] Train Loss: 1.8580   Validation Loss: 1.9805


                                                                                        

Epoch[9] Train Loss: 1.8270   Validation Loss: 2.3896


                                                                                         

Epoch[10] Train Loss: 1.8267   Validation Loss: 2.2717


                                                                                         

Epoch[11] Train Loss: 1.8631   Validation Loss: 2.1189


                                                                                         

Epoch[12] Train Loss: 1.9225   Validation Loss: 1.3342


                                                                                         

Epoch[13] Train Loss: 1.8953   Validation Loss: 2.2555


                                                                                         

Epoch[14] Train Loss: 1.9235   Validation Loss: 2.7992


                                                                                         

Epoch[15] Train Loss: 1.9913   Validation Loss: 1.8029


                                                                                         

Epoch[16] Train Loss: 1.8926   Validation Loss: 1.6126


                                                                                         

Epoch[17] Train Loss: 1.9543   Validation Loss: 2.0823


                                                                                         

Epoch[18] Train Loss: 2.1493   Validation Loss: 1.6429


                                                                                         

Epoch[19] Train Loss: 1.7145   Validation Loss: 1.6621


                                                                                          

Epoch[20] Train Loss: 3.4054   Validation Loss: 4.4848


                                                                                         

Epoch[21] Train Loss: 3.4911   Validation Loss: 4.2627


                                                                                         

Epoch[22] Train Loss: 3.0622   Validation Loss: 2.0725


                                                                                         

Epoch[23] Train Loss: 1.8419   Validation Loss: 2.3591


                                                                                         

Epoch[24] Train Loss: 1.9722   Validation Loss: 2.6252


                                                                                         

Epoch[25] Train Loss: 1.8294   Validation Loss: 1.8857


                                                                                         

Epoch[26] Train Loss: 1.9088   Validation Loss: 1.6854


                                                                                         

Epoch[27] Train Loss: 1.8373   Validation Loss: 2.1851


                                                                                         

Epoch[28] Train Loss: 1.9297   Validation Loss: 2.0278


                                                                                         

Epoch[29] Train Loss: 1.9579   Validation Loss: 1.8961


                                                                                         

Epoch[30] Train Loss: 1.9707   Validation Loss: 3.7540


                                                                                         

Epoch[31] Train Loss: 1.9881   Validation Loss: 2.9430


                                                                                         

Epoch[32] Train Loss: 1.8317   Validation Loss: 1.7930


                                                                                         

Epoch[33] Train Loss: 1.8114   Validation Loss: 1.7211


                                                                                         

Epoch[34] Train Loss: 1.9070   Validation Loss: 2.2553


                                                                                         

Epoch[35] Train Loss: 1.9347   Validation Loss: 1.8700


                                                                                         

Epoch[36] Train Loss: 1.8829   Validation Loss: 1.3961


                                                                                         

Epoch[37] Train Loss: 2.1292   Validation Loss: 2.4623


                                                                                         

Epoch[38] Train Loss: 2.9823   Validation Loss: 2.8649


                                                                                         

Epoch[39] Train Loss: 3.0607   Validation Loss: 2.9910


                                                                                         

Epoch[40] Train Loss: 2.6563   Validation Loss: 3.5299


                                                                                         

Epoch[41] Train Loss: 2.6315   Validation Loss: 3.2935


                                                                                         

Epoch[42] Train Loss: 2.2577   Validation Loss: 2.1259


                                                                                         

Epoch[43] Train Loss: 2.0203   Validation Loss: 3.4667


                                                                                         

Epoch[44] Train Loss: 2.0989   Validation Loss: 1.6451


                                                                                         

Epoch[45] Train Loss: 2.1131   Validation Loss: 2.5956


                                                                                         

Epoch[46] Train Loss: 2.1186   Validation Loss: 1.7257


                                                                                         

Epoch[47] Train Loss: 2.0792   Validation Loss: 2.0684


                                                                                         

Epoch[48] Train Loss: 2.1595   Validation Loss: 2.5129


                                                                                         

Epoch[49] Train Loss: 2.1941   Validation Loss: 2.8839


                                                                                         

Epoch[50] Train Loss: 2.1370   Validation Loss: 2.1500
Test loss: 2.1170

Training class: alarm_clock
Number of samples: 3001
Train set: 2161
Validation set: 240
Test set: 600


                                                                                        

Epoch[1] Train Loss: 4.3702   Validation Loss: 1.6322


                                                                                        

Epoch[2] Train Loss: 2.0119   Validation Loss: 1.6903


                                                                                        

Epoch[3] Train Loss: 1.9737   Validation Loss: 3.3347


                                                                                        

Epoch[4] Train Loss: 2.2546   Validation Loss: 1.5641


                                                                                        

Epoch[5] Train Loss: 1.8718   Validation Loss: 3.2188


                                                                                        

Epoch[6] Train Loss: 1.9387   Validation Loss: 2.1562


                                                                                        

Epoch[7] Train Loss: 1.9911   Validation Loss: 4.2210


                                                                                        

Epoch[8] Train Loss: 1.9275   Validation Loss: 2.1977


                                                                                        

Epoch[9] Train Loss: 1.9422   Validation Loss: 1.5920


                                                                                         

Epoch[10] Train Loss: 2.0216   Validation Loss: 3.2037


                                                                                         

Epoch[11] Train Loss: 1.9802   Validation Loss: 2.5640


                                                                                         

Epoch[12] Train Loss: 2.0895   Validation Loss: 2.1453


                                                                                         

Epoch[13] Train Loss: 1.9806   Validation Loss: 2.6982


                                                                                         

Epoch[14] Train Loss: 1.9964   Validation Loss: 2.6850


                                                                                         

Epoch[15] Train Loss: 2.0114   Validation Loss: 2.5509


                                                                                         

Epoch[16] Train Loss: 2.0769   Validation Loss: 2.5410


                                                                                         

Epoch[17] Train Loss: 2.1039   Validation Loss: 2.9092


                                                                                         

Epoch[18] Train Loss: 2.0890   Validation Loss: 2.0301


                                                                                         

Epoch[19] Train Loss: 2.0982   Validation Loss: 2.1375


                                                                                         

Epoch[20] Train Loss: 2.1498   Validation Loss: 2.0118


                                                                                         

Epoch[21] Train Loss: 2.1365   Validation Loss: 2.6870


                                                                                         

Epoch[22] Train Loss: 2.1178   Validation Loss: 1.8217


                                                                                         

Epoch[23] Train Loss: 2.0773   Validation Loss: 2.7722


                                                                                         

Epoch[24] Train Loss: 2.1235   Validation Loss: 1.5197


                                                                                         

Epoch[25] Train Loss: 2.0378   Validation Loss: 2.5030


                                                                                         

Epoch[26] Train Loss: 2.1497   Validation Loss: 5.6968


                                                                                         

Epoch[27] Train Loss: 4.3493   Validation Loss: 2.0719


                                                                                         

Epoch[28] Train Loss: 1.6766   Validation Loss: 1.8682


                                                                                         

Epoch[29] Train Loss: 1.8396   Validation Loss: 1.3236


                                                                                         

Epoch[30] Train Loss: 2.0103   Validation Loss: 1.8039


                                                                                         

Epoch[31] Train Loss: 2.0177   Validation Loss: 3.0856


                                                                                         

Epoch[32] Train Loss: 2.0821   Validation Loss: 2.2225


                                                                                         

Epoch[33] Train Loss: 2.1015   Validation Loss: 1.8218


                                                                                         

Epoch[34] Train Loss: 2.0488   Validation Loss: 3.3576


                                                                                         

Epoch[35] Train Loss: 2.1080   Validation Loss: 2.0450


                                                                                         

Epoch[36] Train Loss: 2.2374   Validation Loss: 2.3336


                                                                                         

Epoch[37] Train Loss: 2.1636   Validation Loss: 2.0651


                                                                                         

Epoch[38] Train Loss: 2.1481   Validation Loss: 2.2964


                                                                                         

Epoch[39] Train Loss: 2.1491   Validation Loss: 2.1076


                                                                                         

Epoch[40] Train Loss: 2.1512   Validation Loss: 2.0903


                                                                                         

Epoch[41] Train Loss: 1.9935   Validation Loss: 3.6871


                                                                                         

Epoch[42] Train Loss: 2.1196   Validation Loss: 2.5514


                                                                                         

Epoch[43] Train Loss: 2.1533   Validation Loss: 2.1317


                                                                                         

Epoch[44] Train Loss: 2.0636   Validation Loss: 2.2983


                                                                                         

Epoch[45] Train Loss: 2.1323   Validation Loss: 3.2056


                                                                                         

Epoch[46] Train Loss: 2.2486   Validation Loss: 2.1186


                                                                                         

Epoch[47] Train Loss: 2.1004   Validation Loss: 2.5607


                                                                                         

Epoch[48] Train Loss: 2.1671   Validation Loss: 2.3042


                                                                                         

Epoch[49] Train Loss: 2.2093   Validation Loss: 2.0329


                                                                                         

Epoch[50] Train Loss: 2.0415   Validation Loss: 3.7839
Test loss: 3.7895

Training class: axe
Number of samples: 3001
Train set: 2161
Validation set: 240
Test set: 600


                                                                                        

Epoch[1] Train Loss: 4.2227   Validation Loss: 1.9144


                                                                                        

Epoch[2] Train Loss: 1.8389   Validation Loss: 2.2346


                                                                                        

Epoch[3] Train Loss: 1.7929   Validation Loss: 1.7472


                                                                                        

Epoch[4] Train Loss: 1.7415   Validation Loss: 1.9081


                                                                                        

Epoch[5] Train Loss: 1.7681   Validation Loss: 1.6068


                                                                                        

Epoch[6] Train Loss: 1.7574   Validation Loss: 1.5711


                                                                                        

Epoch[7] Train Loss: 1.7680   Validation Loss: 2.2158


                                                                                        

Epoch[8] Train Loss: 1.8122   Validation Loss: 2.9800


                                                                                        

Epoch[9] Train Loss: 2.4278   Validation Loss: 2.8246


                                                                                         

Epoch[10] Train Loss: 2.5392   Validation Loss: 2.5076


                                                                                         

Epoch[11] Train Loss: 2.6123   Validation Loss: 4.7860


                                                                                         

Epoch[12] Train Loss: 2.5985   Validation Loss: 2.4125


                                                                                         

Epoch[13] Train Loss: 2.5956   Validation Loss: 4.0747


                                                                                         

Epoch[14] Train Loss: 3.1924   Validation Loss: 3.2007


                                                                                         

Epoch[15] Train Loss: 3.2020   Validation Loss: 3.2410


                                                                                         

Epoch[16] Train Loss: 3.1826   Validation Loss: 3.1993


                                                                                         

Epoch[17] Train Loss: 3.1953   Validation Loss: 3.1813


                                                                                         

Epoch[18] Train Loss: 3.0888   Validation Loss: 2.8828


                                                                                         

Epoch[19] Train Loss: 2.7417   Validation Loss: 3.2224


                                                                                         

Epoch[20] Train Loss: 2.7748   Validation Loss: 2.9270


                                                                                         

Epoch[21] Train Loss: 2.7663   Validation Loss: 2.4340


                                                                                         

Epoch[22] Train Loss: 2.5544   Validation Loss: 3.0841


                                                                                         

Epoch[23] Train Loss: 2.7073   Validation Loss: 2.8995


                                                                                         

Epoch[24] Train Loss: 2.7201   Validation Loss: 2.4213


                                                                                         

Epoch[25] Train Loss: 2.6692   Validation Loss: 2.9324


                                                                                         

Epoch[26] Train Loss: 2.7401   Validation Loss: 3.2322


                                                                                         

Epoch[27] Train Loss: 2.7710   Validation Loss: 2.5110


                                                                                         

Epoch[28] Train Loss: 2.7421   Validation Loss: 2.7550


                                                                                         

Epoch[29] Train Loss: 2.7154   Validation Loss: 2.6178


                                                                                          

Epoch[30] Train Loss: 3.1557   Validation Loss: 2.6771


                                                                                         

Epoch[31] Train Loss: 2.5795   Validation Loss: 3.4215


                                                                                         

Epoch[32] Train Loss: 2.6797   Validation Loss: 2.4662


                                                                                         

Epoch[33] Train Loss: 2.6915   Validation Loss: 3.0031


                                                                                         

Epoch[34] Train Loss: 2.8145   Validation Loss: 3.8611


                                                                                         

Epoch[35] Train Loss: 2.7763   Validation Loss: 4.2202


                                                                                         

Epoch[36] Train Loss: 2.7946   Validation Loss: 2.6940


                                                                                         

Epoch[37] Train Loss: 2.9335   Validation Loss: 2.9299


                                                                                         

Epoch[38] Train Loss: 2.8210   Validation Loss: 2.4539


                                                                                         

Epoch[39] Train Loss: 2.6930   Validation Loss: 2.7800


                                                                                         

Epoch[40] Train Loss: 2.7140   Validation Loss: 3.1003


                                                                                         

Epoch[41] Train Loss: 2.7441   Validation Loss: 3.2743


                                                                                         

Epoch[42] Train Loss: 2.8839   Validation Loss: 4.0681


                                                                                         

Epoch[43] Train Loss: 2.8374   Validation Loss: 4.1243


                                                                                         

Epoch[44] Train Loss: 2.8022   Validation Loss: 3.3805


                                                                                         

Epoch[45] Train Loss: 2.8288   Validation Loss: 2.7500


                                                                                         

Epoch[46] Train Loss: 2.8299   Validation Loss: 2.4190


                                                                                         

Epoch[47] Train Loss: 2.7930   Validation Loss: 3.0452


                                                                                         

Epoch[48] Train Loss: 2.7753   Validation Loss: 3.1066


                                                                                         

Epoch[49] Train Loss: 2.7513   Validation Loss: 2.4354


                                                                                         

Epoch[50] Train Loss: 2.7779   Validation Loss: 2.5859
Test loss: 2.5721

Training class: bicycle
Number of samples: 3001
Train set: 2161
Validation set: 240
Test set: 600


                                                                                        

Epoch[1] Train Loss: 5.1880   Validation Loss: 2.4510


                                                                                        

Epoch[2] Train Loss: 2.3186   Validation Loss: 2.2861


                                                                                        

Epoch[3] Train Loss: 2.2191   Validation Loss: 1.8520


                                                                                        

Epoch[4] Train Loss: 2.1572   Validation Loss: 4.2069


                                                                                        

Epoch[5] Train Loss: 2.1015   Validation Loss: 1.6488


                                                                                        

Epoch[6] Train Loss: 1.9060   Validation Loss: 1.7290


                                                                                        

Epoch[7] Train Loss: 2.4015   Validation Loss: 1.6770


                                                                                        

Epoch[8] Train Loss: 2.1323   Validation Loss: 2.0502


                                                                                        

Epoch[9] Train Loss: 2.0083   Validation Loss: 2.1493


                                                                                         

Epoch[10] Train Loss: 2.1684   Validation Loss: 2.3411


                                                                                         

Epoch[11] Train Loss: 2.0342   Validation Loss: 1.9395


                                                                                         

Epoch[12] Train Loss: 2.0475   Validation Loss: 2.3252


                                                                                         

Epoch[13] Train Loss: 2.0680   Validation Loss: 2.0730


                                                                                         

Epoch[14] Train Loss: 2.1008   Validation Loss: 2.0245


                                                                                         

Epoch[15] Train Loss: 2.1024   Validation Loss: 2.6811


                                                                                         

Epoch[16] Train Loss: 2.1018   Validation Loss: 3.7566


                                                                                         

Epoch[17] Train Loss: 2.0864   Validation Loss: 1.5300


                                                                                         

Epoch[18] Train Loss: 2.1190   Validation Loss: 2.0968


                                                                                         

Epoch[19] Train Loss: 2.3276   Validation Loss: 3.4748


                                                                                         

Epoch[20] Train Loss: 2.1554   Validation Loss: 2.4416


                                                                                         

Epoch[21] Train Loss: 2.3807   Validation Loss: 2.7922


                                                                                           

Epoch[22] Train Loss: 3.9156   Validation Loss: 5.1718


                                                                                         

Epoch[23] Train Loss: 2.7808   Validation Loss: 1.5827


                                                                                         

Epoch[24] Train Loss: 1.7886   Validation Loss: 1.7142


                                                                                         

Epoch[25] Train Loss: 1.8034   Validation Loss: 1.5379


                                                                                         

Epoch[26] Train Loss: 1.8099   Validation Loss: 1.8117


                                                                                         

Epoch[27] Train Loss: 1.8666   Validation Loss: 1.6659


                                                                                         

Epoch[28] Train Loss: 1.8699   Validation Loss: 1.9240


                                                                                         

Epoch[29] Train Loss: 1.8746   Validation Loss: 2.7031


                                                                                         

Epoch[30] Train Loss: 1.9130   Validation Loss: 2.3359


                                                                                         

Epoch[31] Train Loss: 1.9348   Validation Loss: 2.0792


                                                                                         

Epoch[32] Train Loss: 1.9131   Validation Loss: 2.3901


                                                                                         

Epoch[33] Train Loss: 1.9250   Validation Loss: 1.5353


                                                                                         

Epoch[34] Train Loss: 1.8805   Validation Loss: 2.4832


                                                                                         

Epoch[35] Train Loss: 1.9161   Validation Loss: 1.7501


                                                                                         

Epoch[36] Train Loss: 1.9276   Validation Loss: 2.8910


                                                                                         

Epoch[37] Train Loss: 1.8924   Validation Loss: 3.0267


                                                                                         

Epoch[38] Train Loss: 1.9556   Validation Loss: 4.0761


                                                                                         

Epoch[39] Train Loss: 2.1872   Validation Loss: 1.4421


                                                                                         

Epoch[40] Train Loss: 1.9683   Validation Loss: 1.5456


                                                                                         

Epoch[41] Train Loss: 2.0563   Validation Loss: 2.1322


                                                                                         

Epoch[42] Train Loss: 1.9854   Validation Loss: 1.8613


                                                                                         

Epoch[43] Train Loss: 2.0311   Validation Loss: 1.6347


                                                                                         

Epoch[44] Train Loss: 2.0127   Validation Loss: 2.0249


                                                                                         

Epoch[45] Train Loss: 2.0483   Validation Loss: 2.0986


                                                                                         

Epoch[46] Train Loss: 2.0109   Validation Loss: 3.3596


                                                                                         

Epoch[47] Train Loss: 2.1697   Validation Loss: 2.2070


                                                                                         

Epoch[48] Train Loss: 2.1595   Validation Loss: 1.9456


                                                                                         

Epoch[49] Train Loss: 2.0223   Validation Loss: 1.8156


                                                                                         

Epoch[50] Train Loss: 1.8393   Validation Loss: 2.0235
Test loss: 2.0216

Training class: butterfly
Number of samples: 3001
Train set: 2161
Validation set: 240
Test set: 600


                                                                                        

Epoch[1] Train Loss: 4.4109   Validation Loss: 2.0448


                                                                                        

Epoch[2] Train Loss: 1.9376   Validation Loss: 1.9610


                                                                                        

Epoch[3] Train Loss: 1.8540   Validation Loss: 3.2624


                                                                                        

Epoch[4] Train Loss: 1.9448   Validation Loss: 1.7870


                                                                                        

Epoch[5] Train Loss: 1.8846   Validation Loss: 2.8365


                                                                                        

Epoch[6] Train Loss: 1.8443   Validation Loss: 1.4948


                                                                                        

Epoch[7] Train Loss: 1.8353   Validation Loss: 1.8272


                                                                                        

Epoch[8] Train Loss: 1.8769   Validation Loss: 2.2366


                                                                                         

Epoch[9] Train Loss: 2.1755   Validation Loss: 2.0225


                                                                                         

Epoch[10] Train Loss: 1.8188   Validation Loss: 2.0795


                                                                                         

Epoch[11] Train Loss: 1.9334   Validation Loss: 3.0049


                                                                                         

Epoch[12] Train Loss: 1.9419   Validation Loss: 1.8873


                                                                                         

Epoch[13] Train Loss: 1.9658   Validation Loss: 4.3508


                                                                                         

Epoch[14] Train Loss: 1.9595   Validation Loss: 1.8737


                                                                                         

Epoch[15] Train Loss: 2.3428   Validation Loss: 1.5381


                                                                                         

Epoch[16] Train Loss: 1.5956   Validation Loss: 2.1606


                                                                                         

Epoch[17] Train Loss: 1.7331   Validation Loss: 2.2274


                                                                                         

Epoch[18] Train Loss: 1.8706   Validation Loss: 2.3089


                                                                                         

Epoch[19] Train Loss: 1.8549   Validation Loss: 1.7674


                                                                                         

Epoch[20] Train Loss: 1.8703   Validation Loss: 2.3610


                                                                                         

Epoch[21] Train Loss: 1.9158   Validation Loss: 2.0923


                                                                                         

Epoch[22] Train Loss: 1.9833   Validation Loss: 1.6467


                                                                                         

Epoch[23] Train Loss: 1.9221   Validation Loss: 2.8923


                                                                                         

Epoch[24] Train Loss: 1.9629   Validation Loss: 1.4902


                                                                                         

Epoch[25] Train Loss: 1.9994   Validation Loss: 1.7468


                                                                                         

Epoch[26] Train Loss: 2.0529   Validation Loss: 2.5116


                                                                                         

Epoch[27] Train Loss: 2.0513   Validation Loss: 2.7731


                                                                                         

Epoch[28] Train Loss: 1.9870   Validation Loss: 1.7968


                                                                                         

Epoch[29] Train Loss: 2.0538   Validation Loss: 1.5439


                                                                                         

Epoch[30] Train Loss: 2.1141   Validation Loss: 2.3898


                                                                                         

Epoch[31] Train Loss: 2.0941   Validation Loss: 2.3584


Epoch 32:  51%|█████▏    | 70/136 [01:55<01:48,  1.65s/it, G_loss=1.6980, D_loss=0.0906]