### Training monitoring

In [2]:
import torch
from torchvision import transforms
import gc

from ClassComp.models.vgg import VGG
from ClassComp.models.resnet import ResNet
from ClassComp.models.unet import UNet
from ClassComp.models.vae import VAE, SVMLoss, VAE_conv
from ClassComp.data_utils.loaders import get_mnist, get_dataloader
from ClassComp.experiments.train import train_binary_classifier, train_vae, train_vae_kmeans

epochs = 10 # can go up to  20 but will be very slow
image_size = 64
batch_size = 16
custom_transforms = None

learning_rates = [1e-3]

## Get binary MNIST subsets
train_subset, test_subset = get_mnist()

## Get DataLoaders with additional transformations
train_loader, test_loader = get_dataloader(train_subset, test_subset, transform=custom_transforms, size=image_size, batch_size=batch_size)



def del_model(model):
    del model
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.reset_accumulated_memory_stats()
    torch.cuda.empty_cache()
    gc.collect()



In [5]:

# Vanilla training

description = "vanilla"

## Train
for i, lr in enumerate(learning_rates):
    print("Training VGG")
    vgg = VGG(input_img_size=image_size, input_img_c=1)
    train_binary_classifier(vgg, train_loader, test_loader, epochs=epochs, learning_rate=lr, device="cuda", save_results=True, description=description)

    if i == 0:
        torch.save(vgg.state_dict(), f"./results/models/vgg_{epochs}_epochs_{lr}_lr_{description}.pth.tar")

    del_model(vgg)

    print("Training Resnet")
    resnet = ResNet(input_img_size=image_size, input_img_c=1)
    train_binary_classifier(resnet, train_loader, test_loader, epochs=epochs, learning_rate=lr, device="cuda", save_results=True, description=description)

    if i == 0:
        resnet_path = f"./results/models/resnet_{epochs}_epochs_{lr}_lr_{description}.pth.tar"
        torch.save(resnet.state_dict(), resnet_path)

    del_model(resnet)

    print("Training Unet")
    for j in range(2):
        if j == 0:
            resnet = ResNet(input_img_size=image_size, input_img_c=1)  
            unet_description = description + "_pretrained_resnet"
            resnet.load_state_dict(torch.load(resnet_path))
            resnet.eval()
            resnet.to("cuda")
            unet = UNet(image_size, resnet)
        else:
            unet_description = description + "no_pretraining"
            unet = UNet(image_size)
        
        
        train_binary_classifier(unet, train_loader, test_loader, epochs=epochs, learning_rate=lr, device="cuda", save_results=True, description=unet_description)

        if i == 0:
            torch.save(unet.state_dict(), f"./results/models/unet_{epochs}_epochs_{lr}_lr_{unet_description}.pth.tar")

        del_model(unet)

    print("Training VAE")
    # for j in range(2):
    vae = VAE(image_size**2, 32*32, 8, beta=0.05)
        # if j == 0:
        #     resnet = ResNet(input_img_size=image_size, input_img_c=1)  
        #     resnet.load_state_dict(torch.load(resnet_path))
        #     resnet.eval()
        #     resnet.to("cuda")
        #     vae = VAE_conv(image_size, resnet)
        #     vae_description = description + "_pretrained_resnet"
        # else:
        #     vae_description = description + "no_pretraining"
        #     vae = VAE_conv(image_size)

    train_vae(vae, train_loader, test_loader, epochs=epochs, learning_rate=lr, device="cuda", save_results=True, description=description)

    if i == 0:
        torch.save(vae.state_dict(), f"./results/models/vae_{epochs}_epochs_{lr}_lr_{description}_generation.pth.tar")

    vae.classification_mode = "SVM"

    # Train VAE classification head
    criterion = SVMLoss()
    for param in vae.parameters():
        param.requires_grad = False
    for name, param in vae.named_parameters():
        if "svm_layer" in name:
            param.requires_grad = True
    train_binary_classifier(vae, train_loader, test_loader, criterion=criterion, epochs=epochs, learning_rate=lr, device="cuda", save_results=True, description=description)

    if i == 0:
        torch.save(vae.state_dict(), f"./results/models/vae_{epochs}_epochs_{lr}_lr_{description}_classification.pth.tar")

    train_vae_kmeans(vae, train_loader, test_loader, "cuda", True, description)
    
    del_model(vae)



Training VAE


  vae.load_state_dict(torch.load(f"./results/models/vae_{epochs}_epochs_{lr}_lr_{description}_generation.pth.tar"))


0.9446808510638298
Training metrics saved to results/training/VAE_vanilla_kmeans.pkl
Kmeans saved to results/models/VAE_vanilla_kmeans.pkl
Training VAE


  vae.load_state_dict(torch.load(f"./results/models/vae_{epochs}_epochs_{lr}_lr_{description}_generation.pth.tar"))


FileNotFoundError: [Errno 2] No such file or directory: './results/models/vae_10_epochs_1e-05_lr_vanilla_generation.pth.tar'

In [None]:
# Noisy training

image_size = 64
batch_size = 16

class AddGaussianNoise(object):
    def __init__(self, mean=0., std=1.):
        self.std = std
        self.mean = mean
        
    def __call__(self, tensor):
        return tensor + torch.randn(tensor.size()) * self.std + self.mean
    
    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
    
custom_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),  # 50% chance of horizontal flip
    transforms.RandomVerticalFlip(p=0.5),    # 50% chance of vertical flip
    transforms.RandomAffine(
        degrees=30,                          # Random rotation within ±30 degrees
        translate=(0.1, 0.1),                # Random translation up to 10% of image size
        scale=(0.9, 1.1),                    # Random scaling between 90% and 110%
        shear=10                             # Random shear within ±10 degrees
    ),
    transforms.RandomPerspective(
        distortion_scale=0.5,                # Distortion scale for perspective transform
        p=0.5                                # 50% chance of applying
    ),
    transforms.RandomErasing(
        p=0.5,                               # 50% chance of applying
        scale=(0.02, 0.2),                   # Proportion of erased area
        ratio=(0.3, 3.3),                    # Aspect ratio of erased area
        value=0                              # Fill value for erased pixels
    ),
    transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 2.0)),  # Add Gaussian Blur
    transforms.RandomAdjustSharpness(sharpness_factor=2, p=0.5),  # Adjust sharpness
    AddGaussianNoise(),
])

learning_rates = [1e-3, 1e-5]
description = "class_imbalance_noisy"

## Get binary MNIST subsets
train_subset, test_subset = get_mnist()

## Get DataLoaders with additional transformations
train_loader, test_loader = get_dataloader(train_subset, test_subset, transform=custom_transforms, size=image_size, batch_size=batch_size, class_imbalance=0.6)

## Train
for i, lr in enumerate(learning_rates):
    print("Training VGG")
    vgg = VGG(input_img_size=image_size, input_img_c=1)
    train_binary_classifier(vgg, train_loader, test_loader, epochs=epochs, learning_rate=lr, device="cuda", save_results=True, description=description)

    if i == 0:
        torch.save(vgg.state_dict(), f"./results/models/vgg_{epochs}_epochs_{lr}_lr_{description}.pth.tar")

    del_model(vgg)

    print("Training Resnet")
    resnet = ResNet(input_img_size=image_size, input_img_c=1)
    train_binary_classifier(resnet, train_loader, test_loader, epochs=epochs, learning_rate=lr, device="cuda", save_results=True, description=description)

    if i == 0:
        resnet_path = f"./results/models/resnet_{epochs}_epochs_{lr}_lr_{description}.pth.tar"
        torch.save(resnet.state_dict(), resnet_path)

    del_model(resnet)

    print("Training Unet")
    for j in range(2):
        if j == 0:
            resnet = ResNet(input_img_size=image_size, input_img_c=1)  
            unet_description = description + "_pretrained_resnet"
            resnet.load_state_dict(torch.load(resnet_path))
            resnet.eval()
            resnet.to("cuda")
            unet = UNet(image_size, resnet)
        else:
            unet_description = description + "no_pretraining"
            unet = UNet(image_size)

        train_binary_classifier(unet, train_loader, test_loader, epochs=epochs, learning_rate=lr, device="cuda", save_results=True, description=unet_description)

        if i == 0:
            torch.save(unet.state_dict(), f"./results/models/unet_{epochs}_epochs_{lr}_lr_{unet_description}.pth.tar")

        del_model(unet)

    print("Training VAE")
    # for j in range(2):
    vae = VAE(image_size**2, 32*32, 8, beta=0.05)
        

    train_vae(vae, train_loader, test_loader, epochs=epochs, learning_rate=lr, device="cuda", save_results=True, description=description)

    if i == 0:
        torch.save(vae.state_dict(), f"./results/models/vae_{epochs}_epochs_{lr}_lr_{description}_generation.pth.tar")

    vae.classification_mode = "SVM"

    # Train VAE classification head
    criterion = SVMLoss()
    for param in vae.parameters():
        param.requires_grad = False
    for name, param in vae.named_parameters():
        if "svm_layer" in name:
            param.requires_grad = True
    train_binary_classifier(vae, train_loader, test_loader, criterion=criterion, epochs=epochs, learning_rate=lr, device="cuda", save_results=True, description=description)

    if i == 0:
        torch.save(vae.state_dict(), f"./results/models/vae_{epochs}_epochs_{lr}_lr_{description}_classification.pth.tar")

    train_vae_kmeans(vae, train_loader, test_loader, "cuda", True, description)


    del_model(vae)


Training VGG


                                                             

Epoch [1/20], Loss: 0.6649, Training Accuracy: 59.14%, Time: 89.74s, Gradient Norm: 5.7491
Validation Accuracy: 53.66%


                                                             

Epoch [2/20], Loss: 0.6199, Training Accuracy: 60.53%, Time: 87.08s, Gradient Norm: 2.5078
Validation Accuracy: 53.66%


                                                             

Epoch [3/20], Loss: 0.6073, Training Accuracy: 61.37%, Time: 83.01s, Gradient Norm: 2.2063
Validation Accuracy: 75.84%


                                                             

Epoch [4/20], Loss: 0.3985, Training Accuracy: 84.05%, Time: 84.22s, Gradient Norm: 3.8266
Validation Accuracy: 85.77%


                                                             

Epoch [5/20], Loss: 0.3268, Training Accuracy: 88.60%, Time: 83.53s, Gradient Norm: 4.2111
Validation Accuracy: 86.38%


                                                             

Epoch [6/20], Loss: 0.3119, Training Accuracy: 89.44%, Time: 83.18s, Gradient Norm: 4.2406
Validation Accuracy: 88.18%


                                                             

Epoch [7/20], Loss: 0.2462, Training Accuracy: 91.60%, Time: 86.55s, Gradient Norm: 3.9763
Validation Accuracy: 92.96%


                                                             

Epoch [8/20], Loss: 0.1969, Training Accuracy: 93.16%, Time: 85.33s, Gradient Norm: 3.4662
Validation Accuracy: 94.52%


                                                             

Epoch [9/20], Loss: 0.2104, Training Accuracy: 93.49%, Time: 83.82s, Gradient Norm: 3.6046
Validation Accuracy: 95.18%


                                                              

Epoch [10/20], Loss: 0.1777, Training Accuracy: 94.17%, Time: 85.19s, Gradient Norm: 3.4295
Validation Accuracy: 80.28%


                                                              

Epoch [11/20], Loss: 0.2749, Training Accuracy: 90.99%, Time: 93.84s, Gradient Norm: 4.8657
Validation Accuracy: 61.51%


                                                              

Epoch [12/20], Loss: 0.4091, Training Accuracy: 83.09%, Time: 100.27s, Gradient Norm: 4.2154
Validation Accuracy: 87.85%


                                                              

Epoch [13/20], Loss: 0.2863, Training Accuracy: 89.51%, Time: 105.11s, Gradient Norm: 3.9052
Validation Accuracy: 91.87%


                                                              

Epoch [14/20], Loss: 0.2310, Training Accuracy: 91.64%, Time: 96.71s, Gradient Norm: 3.3858
Validation Accuracy: 93.33%


                                                              

Epoch [15/20], Loss: 0.1697, Training Accuracy: 94.35%, Time: 94.28s, Gradient Norm: 2.7388
Validation Accuracy: 90.02%


                                                              

Epoch [16/20], Loss: 0.2080, Training Accuracy: 93.17%, Time: 89.48s, Gradient Norm: 3.3580
Validation Accuracy: 89.93%


                                                              

Epoch [17/20], Loss: 0.1906, Training Accuracy: 93.59%, Time: 103.96s, Gradient Norm: 2.9794
Validation Accuracy: 95.18%


                                                              

Epoch [18/20], Loss: 0.1192, Training Accuracy: 95.71%, Time: 87.68s, Gradient Norm: 2.1096
Validation Accuracy: 94.99%


                                                              

Epoch [19/20], Loss: 0.1343, Training Accuracy: 95.81%, Time: 89.50s, Gradient Norm: 2.2052
Validation Accuracy: 96.31%


                                                              

Epoch [20/20], Loss: 0.2179, Training Accuracy: 93.26%, Time: 90.87s, Gradient Norm: 3.7998
Validation Accuracy: 94.37%
Training metrics saved to results/training/VGG_20_epochs_0.001_lr_class_imbalance_noisy.pkl
Training Resnet


                                                             

Epoch [1/20], Loss: 0.5309, Training Accuracy: 83.05%, Time: 95.85s, Gradient Norm: 12.5283
Validation Accuracy: 74.23%


                                                             

Epoch [2/20], Loss: 0.3136, Training Accuracy: 90.39%, Time: 98.31s, Gradient Norm: 8.3593
Validation Accuracy: 87.19%


                                                             

Epoch [3/20], Loss: 0.2607, Training Accuracy: 92.82%, Time: 102.29s, Gradient Norm: 8.4524
Validation Accuracy: 94.23%


                                                             

Epoch [4/20], Loss: 0.1822, Training Accuracy: 94.74%, Time: 103.96s, Gradient Norm: 5.4577
Validation Accuracy: 96.26%


                                                             

Epoch [5/20], Loss: 0.1841, Training Accuracy: 94.17%, Time: 103.06s, Gradient Norm: 5.5417
Validation Accuracy: 95.65%


                                                             

Epoch [6/20], Loss: 0.1253, Training Accuracy: 96.00%, Time: 100.76s, Gradient Norm: 3.2966
Validation Accuracy: 97.12%


                                                             

Epoch [7/20], Loss: 0.1255, Training Accuracy: 96.08%, Time: 92.14s, Gradient Norm: 3.4909
Validation Accuracy: 95.65%


                                                             

Epoch [8/20], Loss: 0.1348, Training Accuracy: 95.71%, Time: 92.25s, Gradient Norm: 3.5541
Validation Accuracy: 96.08%


                                                             

Epoch [9/20], Loss: 0.1427, Training Accuracy: 95.74%, Time: 93.42s, Gradient Norm: 3.7625
Validation Accuracy: 94.04%


                                                              

Epoch [10/20], Loss: 0.1524, Training Accuracy: 95.51%, Time: 92.31s, Gradient Norm: 3.8760
Validation Accuracy: 96.26%


                                                              

Epoch [11/20], Loss: 0.1098, Training Accuracy: 96.45%, Time: 92.90s, Gradient Norm: 2.5021
Validation Accuracy: 96.93%


                                                              

Epoch [12/20], Loss: 0.1235, Training Accuracy: 96.13%, Time: 100.65s, Gradient Norm: 3.0451
Validation Accuracy: 96.69%


                                                              

Epoch [13/20], Loss: 0.1310, Training Accuracy: 95.93%, Time: 100.87s, Gradient Norm: 3.0288
Validation Accuracy: 96.55%


                                                              

Epoch [14/20], Loss: 0.1173, Training Accuracy: 96.26%, Time: 100.49s, Gradient Norm: 2.7827
Validation Accuracy: 96.12%


                                                              

Epoch [15/20], Loss: 0.0921, Training Accuracy: 96.90%, Time: 92.21s, Gradient Norm: 1.9142
Validation Accuracy: 96.60%


                                                              

Epoch [16/20], Loss: 0.1222, Training Accuracy: 96.41%, Time: 91.67s, Gradient Norm: 2.4236
Validation Accuracy: 95.41%


                                                              

Epoch [17/20], Loss: 0.2840, Training Accuracy: 94.07%, Time: 91.92s, Gradient Norm: 6.9772
Validation Accuracy: 94.37%


                                                              

Epoch [18/20], Loss: 0.1146, Training Accuracy: 96.54%, Time: 90.77s, Gradient Norm: 2.5112
Validation Accuracy: 96.69%


                                                              

Epoch [19/20], Loss: 0.1113, Training Accuracy: 96.73%, Time: 90.95s, Gradient Norm: 2.0026
Validation Accuracy: 96.36%


                                                              

Epoch [20/20], Loss: 0.1128, Training Accuracy: 96.51%, Time: 90.81s, Gradient Norm: 2.0080
Validation Accuracy: 96.26%
Training metrics saved to results/training/ResNet_20_epochs_0.001_lr_class_imbalance_noisy.pkl
Training Unet


  resnet.load_state_dict(torch.load(resnet_path))
                                                             

Epoch [1/20], Loss: 0.1082, Training Accuracy: 96.23%, Time: 103.94s, Gradient Norm: 0.4479
Validation Accuracy: 97.16%


                                                             

Epoch [2/20], Loss: 0.0824, Training Accuracy: 97.20%, Time: 106.25s, Gradient Norm: 0.2519
Validation Accuracy: 97.78%


                                                             

Epoch [3/20], Loss: 0.0803, Training Accuracy: 96.99%, Time: 104.03s, Gradient Norm: 0.2065
Validation Accuracy: 97.12%


                                                             

Epoch [4/20], Loss: 0.0711, Training Accuracy: 97.35%, Time: 103.61s, Gradient Norm: 0.1737
Validation Accuracy: 97.49%


                                                             

Epoch [5/20], Loss: 0.0723, Training Accuracy: 97.55%, Time: 104.01s, Gradient Norm: 0.1787
Validation Accuracy: 97.16%


                                                             

Epoch [6/20], Loss: 0.0735, Training Accuracy: 97.37%, Time: 104.42s, Gradient Norm: 0.1613
Validation Accuracy: 97.07%


                                                             

Epoch [7/20], Loss: 0.0785, Training Accuracy: 97.09%, Time: 110.18s, Gradient Norm: 0.1747
Validation Accuracy: 96.74%


                                                             

Epoch [8/20], Loss: 0.0714, Training Accuracy: 97.53%, Time: 114.61s, Gradient Norm: 0.1519
Validation Accuracy: 97.35%


                                                             

Epoch [9/20], Loss: 0.0687, Training Accuracy: 97.43%, Time: 112.94s, Gradient Norm: 0.1414
Validation Accuracy: 97.16%


                                                              

Epoch [10/20], Loss: 0.0732, Training Accuracy: 97.47%, Time: 108.31s, Gradient Norm: 0.1467
Validation Accuracy: 97.78%


                                                              

Epoch [11/20], Loss: 0.0687, Training Accuracy: 97.53%, Time: 107.48s, Gradient Norm: 0.1479
Validation Accuracy: 96.64%


                                                              

Epoch [12/20], Loss: 0.0699, Training Accuracy: 97.43%, Time: 109.34s, Gradient Norm: 0.1351
Validation Accuracy: 97.16%


                                                              

Epoch [13/20], Loss: 0.0718, Training Accuracy: 97.33%, Time: 111.30s, Gradient Norm: 0.1372
Validation Accuracy: 97.49%


                                                              

Epoch [14/20], Loss: 0.0753, Training Accuracy: 97.12%, Time: 113.85s, Gradient Norm: 0.1389
Validation Accuracy: 97.21%


                                                              

Epoch [15/20], Loss: 0.0696, Training Accuracy: 97.51%, Time: 132.13s, Gradient Norm: 0.1411
Validation Accuracy: 97.07%


                                                              

Epoch [16/20], Loss: 0.0719, Training Accuracy: 97.39%, Time: 109.50s, Gradient Norm: 0.1340
Validation Accuracy: 96.78%


                                                              

Epoch [17/20], Loss: 0.0722, Training Accuracy: 97.35%, Time: 110.25s, Gradient Norm: 0.1333
Validation Accuracy: 96.93%


                                                              

Epoch [18/20], Loss: 0.0711, Training Accuracy: 97.54%, Time: 109.77s, Gradient Norm: 0.1266
Validation Accuracy: 97.87%


                                                              

Epoch [19/20], Loss: 0.0677, Training Accuracy: 97.46%, Time: 108.12s, Gradient Norm: 0.1300
Validation Accuracy: 98.06%


                                                              

Epoch [20/20], Loss: 0.0618, Training Accuracy: 97.67%, Time: 108.52s, Gradient Norm: 0.1229
Validation Accuracy: 97.73%
Training metrics saved to results/training/UNet_20_epochs_0.001_lr_class_imbalance_noisy_pretrained_resnet.pkl


                                                             

Epoch [1/20], Loss: 0.2734, Training Accuracy: 88.87%, Time: 124.36s, Gradient Norm: 1.8637
Validation Accuracy: 93.33%


                                                             

Epoch [2/20], Loss: 0.1535, Training Accuracy: 94.47%, Time: 129.00s, Gradient Norm: 1.2756
Validation Accuracy: 95.89%


                                                             

Epoch [3/20], Loss: 0.1148, Training Accuracy: 95.73%, Time: 133.98s, Gradient Norm: 1.0100
Validation Accuracy: 96.69%


                                                             

Epoch [4/20], Loss: 0.0949, Training Accuracy: 96.57%, Time: 134.94s, Gradient Norm: 0.8242
Validation Accuracy: 96.88%


                                                             

Epoch [5/20], Loss: 0.0949, Training Accuracy: 96.57%, Time: 128.36s, Gradient Norm: 0.7947
Validation Accuracy: 96.64%


                                                             

Epoch [6/20], Loss: 0.0845, Training Accuracy: 97.00%, Time: 128.63s, Gradient Norm: 0.6908
Validation Accuracy: 96.31%


                                                             

Epoch [7/20], Loss: 0.0812, Training Accuracy: 96.98%, Time: 125.30s, Gradient Norm: 0.6631
Validation Accuracy: 97.35%


                                                             

Epoch [8/20], Loss: 0.0825, Training Accuracy: 97.04%, Time: 125.40s, Gradient Norm: 0.6958
Validation Accuracy: 96.69%


                                                             

Epoch [9/20], Loss: 0.0807, Training Accuracy: 97.11%, Time: 124.90s, Gradient Norm: 0.6153
Validation Accuracy: 97.59%


                                                              

Epoch [10/20], Loss: 0.0782, Training Accuracy: 97.10%, Time: 124.49s, Gradient Norm: 0.5932
Validation Accuracy: 96.41%


                                                              

Epoch [11/20], Loss: 0.0767, Training Accuracy: 97.24%, Time: 124.54s, Gradient Norm: 0.5543
Validation Accuracy: 97.54%


                                                              

Epoch [12/20], Loss: 0.0759, Training Accuracy: 97.32%, Time: 125.38s, Gradient Norm: 0.5140
Validation Accuracy: 96.31%


                                                              

Epoch [13/20], Loss: 0.0692, Training Accuracy: 97.56%, Time: 124.93s, Gradient Norm: 0.5149
Validation Accuracy: 97.40%


                                                              

Epoch [14/20], Loss: 0.0760, Training Accuracy: 97.13%, Time: 124.95s, Gradient Norm: 0.5189
Validation Accuracy: 97.64%


                                                              

Epoch [15/20], Loss: 0.0760, Training Accuracy: 97.32%, Time: 124.70s, Gradient Norm: 0.4891
Validation Accuracy: 97.87%


                                                              

Epoch [16/20], Loss: 0.0748, Training Accuracy: 97.38%, Time: 123.85s, Gradient Norm: 0.4875
Validation Accuracy: 97.97%


                                                              

Epoch [17/20], Loss: 0.0679, Training Accuracy: 97.43%, Time: 124.80s, Gradient Norm: 0.4717
Validation Accuracy: 97.30%


                                                              

Epoch [18/20], Loss: 0.0637, Training Accuracy: 97.66%, Time: 124.97s, Gradient Norm: 0.4580
Validation Accuracy: 97.26%


                                                              

Epoch [19/20], Loss: 0.0664, Training Accuracy: 97.74%, Time: 125.11s, Gradient Norm: 0.4728
Validation Accuracy: 97.64%


                                                              

Epoch [20/20], Loss: 0.0647, Training Accuracy: 97.56%, Time: 126.29s, Gradient Norm: 0.4206
Validation Accuracy: 97.78%
Training metrics saved to results/training/UNet_20_epochs_0.001_lr_class_imbalance_noisyno_pretraining.pkl
Training VAE


                                                                           

Epoch [1/20], Training Loss: 83631.0679, Reconstruction Loss: 83627.1828, KLD Loss: 77.7028, Time: 61.72s, Gradient Norm: 0.1679
Epoch [1/20], Validation Loss: 83281.8004, Validation Recon: 83281.8004, Validation KLD: 0.0000


                                                                           

Epoch [2/20], Training Loss: 83578.6378, Reconstruction Loss: 83578.6378, KLD Loss: 0.0000, Time: 59.48s, Gradient Norm: 0.0005
Epoch [2/20], Validation Loss: 83465.3584, Validation Recon: 83465.3584, Validation KLD: 0.0000


                                                                           

Epoch [3/20], Training Loss: 83755.8705, Reconstruction Loss: 83755.8705, KLD Loss: 0.0000, Time: 58.91s, Gradient Norm: 0.0003
Epoch [3/20], Validation Loss: 83247.0210, Validation Recon: 83247.0210, Validation KLD: 0.0000


                                                                           

Epoch [4/20], Training Loss: 83445.8572, Reconstruction Loss: 83445.8572, KLD Loss: 0.0000, Time: 60.24s, Gradient Norm: 0.0002
Epoch [4/20], Validation Loss: 83093.0164, Validation Recon: 83093.0164, Validation KLD: 0.0000


                                                                           

Epoch [5/20], Training Loss: 83383.1041, Reconstruction Loss: 83383.1041, KLD Loss: 0.0000, Time: 60.48s, Gradient Norm: 0.0002
Epoch [5/20], Validation Loss: 82618.4500, Validation Recon: 82618.4500, Validation KLD: 0.0000


                                                                           

Epoch [6/20], Training Loss: 83505.7683, Reconstruction Loss: 83505.7683, KLD Loss: 0.0001, Time: 60.78s, Gradient Norm: 0.0301
Epoch [6/20], Validation Loss: 82998.4909, Validation Recon: 82998.4909, Validation KLD: 0.0000


                                                                           

Epoch [7/20], Training Loss: 83733.2273, Reconstruction Loss: 83733.2273, KLD Loss: 0.0000, Time: 60.79s, Gradient Norm: 0.0001
Epoch [7/20], Validation Loss: 83105.0365, Validation Recon: 83105.0365, Validation KLD: -0.0000


                                                                           

Epoch [8/20], Training Loss: 83851.1603, Reconstruction Loss: 83851.1603, KLD Loss: 0.0000, Time: 59.86s, Gradient Norm: 0.0027
Epoch [8/20], Validation Loss: 82766.4815, Validation Recon: 82766.4815, Validation KLD: -0.0000


                                                                           

Epoch [9/20], Training Loss: 83894.4172, Reconstruction Loss: 83894.4172, KLD Loss: 0.0000, Time: 59.14s, Gradient Norm: 0.0001
Epoch [9/20], Validation Loss: 83081.8412, Validation Recon: 83081.8412, Validation KLD: 0.0000


                                                                            

Epoch [10/20], Training Loss: 84041.1913, Reconstruction Loss: 84041.1913, KLD Loss: 0.0000, Time: 59.75s, Gradient Norm: 0.0001
Epoch [10/20], Validation Loss: 83241.5893, Validation Recon: 83241.5893, Validation KLD: 0.0000


                                                                            

Epoch [11/20], Training Loss: 83828.5815, Reconstruction Loss: 83828.5815, KLD Loss: 0.0000, Time: 59.87s, Gradient Norm: 0.0001
Epoch [11/20], Validation Loss: 82981.2312, Validation Recon: 82981.2312, Validation KLD: 0.0000


                                                                            

Epoch [12/20], Training Loss: 83717.6706, Reconstruction Loss: 83717.6706, KLD Loss: 0.0000, Time: 58.09s, Gradient Norm: 0.0000
Epoch [12/20], Validation Loss: 83050.6589, Validation Recon: 83050.6589, Validation KLD: 0.0000


                                                                            

Epoch [13/20], Training Loss: 83961.1526, Reconstruction Loss: 83961.1526, KLD Loss: 0.0000, Time: 57.21s, Gradient Norm: 0.0002
Epoch [13/20], Validation Loss: 82633.8457, Validation Recon: 82633.8457, Validation KLD: 0.0000


                                                                            

Epoch [14/20], Training Loss: 83883.4469, Reconstruction Loss: 83883.4469, KLD Loss: 0.0000, Time: 57.43s, Gradient Norm: 0.0001
Epoch [14/20], Validation Loss: 83463.9195, Validation Recon: 83463.9195, Validation KLD: 0.0000


                                                                            

Epoch [15/20], Training Loss: 83784.4486, Reconstruction Loss: 83784.4486, KLD Loss: 0.0001, Time: 57.36s, Gradient Norm: 0.0010
Epoch [15/20], Validation Loss: 82265.7823, Validation Recon: 82265.7823, Validation KLD: 0.0000


                                                                            

Epoch [16/20], Training Loss: 83915.9795, Reconstruction Loss: 83915.9795, KLD Loss: 0.0000, Time: 56.97s, Gradient Norm: 0.0000
Epoch [16/20], Validation Loss: 83147.8932, Validation Recon: 83147.8932, Validation KLD: 0.0000


                                                                            

Epoch [17/20], Training Loss: 83655.4734, Reconstruction Loss: 83655.4734, KLD Loss: 0.0000, Time: 57.10s, Gradient Norm: 0.0000
Epoch [17/20], Validation Loss: 83312.8489, Validation Recon: 83312.8489, Validation KLD: 0.0000


                                                                            

Epoch [18/20], Training Loss: 83868.9356, Reconstruction Loss: 83868.9356, KLD Loss: 0.0000, Time: 57.39s, Gradient Norm: 0.0000
Epoch [18/20], Validation Loss: 83631.7496, Validation Recon: 83631.7496, Validation KLD: 0.0000


                                                                            

Epoch [19/20], Training Loss: 83990.1337, Reconstruction Loss: 83990.1337, KLD Loss: 0.0000, Time: 56.43s, Gradient Norm: 0.0001
Epoch [19/20], Validation Loss: 82978.3593, Validation Recon: 82978.3593, Validation KLD: -0.0000


                                                                            

Epoch [20/20], Training Loss: 83815.9485, Reconstruction Loss: 83815.9485, KLD Loss: 0.0000, Time: 57.87s, Gradient Norm: 0.0000
Epoch [20/20], Validation Loss: 83324.7233, Validation Recon: 83324.7233, Validation KLD: 0.0000
Training metrics saved to results/training/VAE_20_epochs_0.001_lr_class_imbalance_noisy_generation.pkl


                                                             

Epoch [1/20], Loss: 0.9808, Training Accuracy: 49.25%, Time: 51.92s, Gradient Norm: 0.6506
Validation Accuracy: 49.31%


                                                             

Epoch [2/20], Loss: 0.8945, Training Accuracy: 48.76%, Time: 52.32s, Gradient Norm: 0.6397
Validation Accuracy: 48.18%


                                                             

Epoch [3/20], Loss: 0.8182, Training Accuracy: 46.01%, Time: 52.15s, Gradient Norm: 0.6106
Validation Accuracy: 49.88%


                                                             

Epoch [4/20], Loss: 0.8007, Training Accuracy: 49.51%, Time: 52.05s, Gradient Norm: 0.5811
Validation Accuracy: 49.79%


                                                             

Epoch [5/20], Loss: 0.8012, Training Accuracy: 49.93%, Time: 52.76s, Gradient Norm: 0.5871
Validation Accuracy: 48.89%


                                                             

Epoch [6/20], Loss: 0.8012, Training Accuracy: 49.66%, Time: 52.09s, Gradient Norm: 0.5735
Validation Accuracy: 49.13%


                                                             

Epoch [7/20], Loss: 0.8011, Training Accuracy: 50.43%, Time: 51.92s, Gradient Norm: 0.5915
Validation Accuracy: 47.94%


                                                             

Epoch [8/20], Loss: 0.8010, Training Accuracy: 50.32%, Time: 52.23s, Gradient Norm: 0.5940
Validation Accuracy: 51.63%


                                                             

Epoch [9/20], Loss: 0.8006, Training Accuracy: 49.10%, Time: 52.38s, Gradient Norm: 0.5899
Validation Accuracy: 49.08%


                                                              

Epoch [10/20], Loss: 0.8009, Training Accuracy: 49.98%, Time: 52.45s, Gradient Norm: 0.5791
Validation Accuracy: 49.41%


                                                              

Epoch [11/20], Loss: 0.8007, Training Accuracy: 49.87%, Time: 52.43s, Gradient Norm: 0.5882
Validation Accuracy: 52.15%


                                                              

Epoch [12/20], Loss: 0.8010, Training Accuracy: 50.00%, Time: 52.20s, Gradient Norm: 0.6010
Validation Accuracy: 49.98%


                                                              

Epoch [13/20], Loss: 0.8009, Training Accuracy: 49.86%, Time: 52.18s, Gradient Norm: 0.5798
Validation Accuracy: 50.64%


                                                              

Epoch [14/20], Loss: 0.8009, Training Accuracy: 50.23%, Time: 52.21s, Gradient Norm: 0.5808
Validation Accuracy: 51.77%


                                                              

Epoch [15/20], Loss: 0.8007, Training Accuracy: 49.06%, Time: 53.22s, Gradient Norm: 0.5832
Validation Accuracy: 52.20%


                                                              

Epoch [16/20], Loss: 0.8011, Training Accuracy: 50.53%, Time: 52.35s, Gradient Norm: 0.5871
Validation Accuracy: 50.07%


                                                              

Epoch [17/20], Loss: 0.8010, Training Accuracy: 51.03%, Time: 52.21s, Gradient Norm: 0.5937
Validation Accuracy: 50.87%


                                                              

Epoch [18/20], Loss: 0.8011, Training Accuracy: 50.23%, Time: 52.04s, Gradient Norm: 0.5831
Validation Accuracy: 50.69%


                                                              

Epoch [19/20], Loss: 0.8013, Training Accuracy: 50.93%, Time: 52.37s, Gradient Norm: 0.5862
Validation Accuracy: 50.92%


                                                              

Epoch [20/20], Loss: 0.8011, Training Accuracy: 49.25%, Time: 52.96s, Gradient Norm: 0.5976
Validation Accuracy: 50.87%
Training metrics saved to results/training/VAE_20_epochs_0.001_lr_class_imbalance_noisy.pkl
0.5044917257683215
Training metrics saved to results/training/VAE_class_imbalance_noisy_kmeans.pkl
Kmeans saved to results/models/VAE_class_imbalance_noisy_kmeans.pkl
Training VGG


                                                             

Epoch [1/20], Loss: 0.3841, Training Accuracy: 81.68%, Time: 93.16s, Gradient Norm: 161.8096
Validation Accuracy: 91.58%


                                                             

Epoch [2/20], Loss: 0.2036, Training Accuracy: 91.58%, Time: 92.94s, Gradient Norm: 103.8696
Validation Accuracy: 93.85%


                                                             

Epoch [3/20], Loss: 0.1469, Training Accuracy: 94.28%, Time: 91.38s, Gradient Norm: 68.9292
Validation Accuracy: 94.47%


                                                             

Epoch [4/20], Loss: 0.1186, Training Accuracy: 95.56%, Time: 91.11s, Gradient Norm: 56.8061
Validation Accuracy: 95.56%


                                                             

Epoch [5/20], Loss: 0.1043, Training Accuracy: 96.23%, Time: 94.61s, Gradient Norm: 44.8157
Validation Accuracy: 95.79%


                                                             

Epoch [6/20], Loss: 0.0944, Training Accuracy: 96.60%, Time: 92.90s, Gradient Norm: 36.9224
Validation Accuracy: 96.45%


                                                             

Epoch [7/20], Loss: 0.0899, Training Accuracy: 96.69%, Time: 91.55s, Gradient Norm: 33.5884
Validation Accuracy: 96.97%


                                                             

Epoch [8/20], Loss: 0.0859, Training Accuracy: 96.90%, Time: 91.09s, Gradient Norm: 29.4179
Validation Accuracy: 96.36%


                                                             

Epoch [9/20], Loss: 0.0810, Training Accuracy: 96.96%, Time: 95.63s, Gradient Norm: 26.4634
Validation Accuracy: 96.69%


                                                              

Epoch [10/20], Loss: 0.0817, Training Accuracy: 96.88%, Time: 95.22s, Gradient Norm: 23.5085
Validation Accuracy: 97.12%


                                                              

Epoch [11/20], Loss: 0.0766, Training Accuracy: 97.01%, Time: 94.47s, Gradient Norm: 20.2229
Validation Accuracy: 95.22%


                                                              

Epoch [12/20], Loss: 0.0800, Training Accuracy: 97.04%, Time: 96.85s, Gradient Norm: 20.6136
Validation Accuracy: 97.68%


                                                              

Epoch [13/20], Loss: 0.0727, Training Accuracy: 97.20%, Time: 97.30s, Gradient Norm: 19.5566
Validation Accuracy: 97.87%


                                                              

Epoch [14/20], Loss: 0.0733, Training Accuracy: 97.44%, Time: 95.19s, Gradient Norm: 16.3850
Validation Accuracy: 97.07%


                                                              

Epoch [15/20], Loss: 0.0687, Training Accuracy: 97.56%, Time: 97.25s, Gradient Norm: 15.0484
Validation Accuracy: 97.59%


                                                              

Epoch [16/20], Loss: 0.0764, Training Accuracy: 97.15%, Time: 96.77s, Gradient Norm: 14.8466
Validation Accuracy: 97.54%


                                                              

Epoch [17/20], Loss: 0.0676, Training Accuracy: 97.45%, Time: 96.34s, Gradient Norm: 14.6200
Validation Accuracy: 96.64%


                                                              

Epoch [18/20], Loss: 0.0725, Training Accuracy: 97.43%, Time: 97.57s, Gradient Norm: 13.5339
Validation Accuracy: 97.54%


                                                              

Epoch [19/20], Loss: 0.0684, Training Accuracy: 97.54%, Time: 95.90s, Gradient Norm: 12.9117
Validation Accuracy: 97.30%


                                                              

Epoch [20/20], Loss: 0.0709, Training Accuracy: 97.40%, Time: 96.34s, Gradient Norm: 12.7207
Validation Accuracy: 96.78%
Training metrics saved to results/training/VGG_20_epochs_1e-05_lr_class_imbalance_noisy.pkl
Training Resnet


                                                             

Epoch [1/20], Loss: 0.3668, Training Accuracy: 82.85%, Time: 99.09s, Gradient Norm: 108.7123
Validation Accuracy: 89.36%


                                                             

Epoch [2/20], Loss: 0.1982, Training Accuracy: 91.99%, Time: 98.28s, Gradient Norm: 82.1346
Validation Accuracy: 93.33%


                                                             

Epoch [3/20], Loss: 0.1408, Training Accuracy: 94.53%, Time: 97.30s, Gradient Norm: 59.4076
Validation Accuracy: 95.60%


                                                             

Epoch [4/20], Loss: 0.1159, Training Accuracy: 95.44%, Time: 93.89s, Gradient Norm: 46.8915
Validation Accuracy: 94.47%


                                                             

Epoch [5/20], Loss: 0.1121, Training Accuracy: 95.69%, Time: 91.87s, Gradient Norm: 41.2491
Validation Accuracy: 95.65%


                                                             

Epoch [6/20], Loss: 0.0983, Training Accuracy: 96.35%, Time: 91.98s, Gradient Norm: 34.5237
Validation Accuracy: 96.88%


                                                             

Epoch [7/20], Loss: 0.0941, Training Accuracy: 96.57%, Time: 92.37s, Gradient Norm: 30.2010
Validation Accuracy: 96.93%


                                                             

Epoch [8/20], Loss: 0.0904, Training Accuracy: 96.79%, Time: 91.47s, Gradient Norm: 27.8786
Validation Accuracy: 97.16%


                                                             

Epoch [9/20], Loss: 0.0839, Training Accuracy: 96.94%, Time: 103.66s, Gradient Norm: 25.4289
Validation Accuracy: 96.97%


                                                              

Epoch [10/20], Loss: 0.0835, Training Accuracy: 96.72%, Time: 105.06s, Gradient Norm: 24.3595
Validation Accuracy: 97.07%


                                                              

Epoch [11/20], Loss: 0.0827, Training Accuracy: 96.91%, Time: 101.39s, Gradient Norm: 20.8874
Validation Accuracy: 97.26%


                                                              

Epoch [12/20], Loss: 0.0810, Training Accuracy: 97.07%, Time: 98.23s, Gradient Norm: 20.3164
Validation Accuracy: 97.07%


                                                              

Epoch [13/20], Loss: 0.0777, Training Accuracy: 97.09%, Time: 94.38s, Gradient Norm: 19.8237
Validation Accuracy: 96.60%


                                                              

Epoch [14/20], Loss: 0.0782, Training Accuracy: 96.92%, Time: 94.42s, Gradient Norm: 19.2058
Validation Accuracy: 96.88%


                                                              

Epoch [15/20], Loss: 0.0734, Training Accuracy: 97.32%, Time: 113.84s, Gradient Norm: 16.7765
Validation Accuracy: 96.74%


                                                              

Epoch [16/20], Loss: 0.0685, Training Accuracy: 97.47%, Time: 104.19s, Gradient Norm: 15.9310
Validation Accuracy: 97.59%


                                                              

Epoch [17/20], Loss: 0.0763, Training Accuracy: 97.26%, Time: 120.45s, Gradient Norm: 16.6162
Validation Accuracy: 97.21%


                                                              

Epoch [18/20], Loss: 0.0740, Training Accuracy: 97.38%, Time: 125.65s, Gradient Norm: 16.0504
Validation Accuracy: 96.88%


                                                              

Epoch [19/20], Loss: 0.0760, Training Accuracy: 97.13%, Time: 94.87s, Gradient Norm: 14.8933
Validation Accuracy: 97.35%


                                                              

Epoch [20/20], Loss: 0.0686, Training Accuracy: 97.28%, Time: 94.45s, Gradient Norm: 14.2582
Validation Accuracy: 97.21%
Training metrics saved to results/training/ResNet_20_epochs_1e-05_lr_class_imbalance_noisy.pkl
Training Unet


  resnet.load_state_dict(torch.load(resnet_path))
                                                             

Epoch [1/20], Loss: 0.2507, Training Accuracy: 93.11%, Time: 114.15s, Gradient Norm: 2.6913
Validation Accuracy: 97.02%


                                                             

Epoch [2/20], Loss: 0.1383, Training Accuracy: 97.19%, Time: 115.24s, Gradient Norm: 2.5459
Validation Accuracy: 97.49%


                                                             

Epoch [3/20], Loss: 0.1227, Training Accuracy: 97.17%, Time: 117.26s, Gradient Norm: 2.4416
Validation Accuracy: 97.83%


                                                             

Epoch [4/20], Loss: 0.1082, Training Accuracy: 97.13%, Time: 109.53s, Gradient Norm: 2.3855
Validation Accuracy: 97.35%


                                                             

Epoch [5/20], Loss: 0.0987, Training Accuracy: 97.32%, Time: 119.39s, Gradient Norm: 2.3629
Validation Accuracy: 97.40%


                                                             

Epoch [6/20], Loss: 0.0907, Training Accuracy: 97.36%, Time: 113.60s, Gradient Norm: 2.2431
Validation Accuracy: 97.30%


                                                             

Epoch [7/20], Loss: 0.0882, Training Accuracy: 97.35%, Time: 112.22s, Gradient Norm: 2.1098
Validation Accuracy: 97.35%


                                                             

Epoch [8/20], Loss: 0.0838, Training Accuracy: 97.42%, Time: 113.43s, Gradient Norm: 2.0409
Validation Accuracy: 97.02%


                                                             

Epoch [9/20], Loss: 0.0816, Training Accuracy: 97.41%, Time: 110.40s, Gradient Norm: 2.0479
Validation Accuracy: 97.02%


                                                              

Epoch [10/20], Loss: 0.0799, Training Accuracy: 97.29%, Time: 112.84s, Gradient Norm: 2.0046
Validation Accuracy: 96.74%


                                                              

Epoch [11/20], Loss: 0.0793, Training Accuracy: 97.39%, Time: 115.42s, Gradient Norm: 1.9113
Validation Accuracy: 97.73%


                                                              

Epoch [12/20], Loss: 0.0731, Training Accuracy: 97.54%, Time: 115.69s, Gradient Norm: 1.9122
Validation Accuracy: 97.30%


                                                              

Epoch [13/20], Loss: 0.0706, Training Accuracy: 97.65%, Time: 116.19s, Gradient Norm: 1.9363
Validation Accuracy: 97.40%


                                                              

Epoch [14/20], Loss: 0.0740, Training Accuracy: 97.39%, Time: 115.14s, Gradient Norm: 1.9337
Validation Accuracy: 97.64%


                                                              

Epoch [15/20], Loss: 0.0733, Training Accuracy: 97.50%, Time: 114.44s, Gradient Norm: 1.8888
Validation Accuracy: 97.02%


                                                              

Epoch [16/20], Loss: 0.0700, Training Accuracy: 97.58%, Time: 122.20s, Gradient Norm: 1.8608
Validation Accuracy: 97.92%


                                                              

Epoch [17/20], Loss: 0.0751, Training Accuracy: 97.30%, Time: 114.46s, Gradient Norm: 1.8717
Validation Accuracy: 97.64%


                                                              

Epoch [18/20], Loss: 0.0685, Training Accuracy: 97.55%, Time: 115.56s, Gradient Norm: 1.7799
Validation Accuracy: 97.68%


                                                              

Epoch [19/20], Loss: 0.0637, Training Accuracy: 97.82%, Time: 117.39s, Gradient Norm: 1.6720
Validation Accuracy: 97.83%


                                                              

Epoch [20/20], Loss: 0.0684, Training Accuracy: 97.61%, Time: 115.40s, Gradient Norm: 1.8024
Validation Accuracy: 97.54%
Training metrics saved to results/training/UNet_20_epochs_1e-05_lr_class_imbalance_noisy_pretrained_resnet.pkl


                                                             

Epoch [1/20], Loss: 0.3739, Training Accuracy: 83.56%, Time: 133.03s, Gradient Norm: 21.2007
Validation Accuracy: 92.58%


                                                             

Epoch [2/20], Loss: 0.2069, Training Accuracy: 93.62%, Time: 131.03s, Gradient Norm: 26.0657
Validation Accuracy: 95.04%


                                                             

Epoch [3/20], Loss: 0.1673, Training Accuracy: 95.19%, Time: 133.88s, Gradient Norm: 21.9126
Validation Accuracy: 95.79%


                                                             

Epoch [4/20], Loss: 0.1424, Training Accuracy: 95.84%, Time: 134.71s, Gradient Norm: 20.9509
Validation Accuracy: 96.78%


                                                             

Epoch [5/20], Loss: 0.1229, Training Accuracy: 96.61%, Time: 132.40s, Gradient Norm: 18.5986
Validation Accuracy: 96.64%


                                                             

Epoch [6/20], Loss: 0.1176, Training Accuracy: 96.48%, Time: 134.74s, Gradient Norm: 18.8317
Validation Accuracy: 96.64%


                                                             

Epoch [7/20], Loss: 0.1076, Training Accuracy: 96.75%, Time: 133.18s, Gradient Norm: 17.4388
Validation Accuracy: 97.02%


                                                             

Epoch [8/20], Loss: 0.0998, Training Accuracy: 97.05%, Time: 140.35s, Gradient Norm: 16.0392
Validation Accuracy: 97.30%


                                                             

Epoch [9/20], Loss: 0.0942, Training Accuracy: 96.97%, Time: 164.47s, Gradient Norm: 15.8288
Validation Accuracy: 97.45%


                                                              

Epoch [10/20], Loss: 0.0896, Training Accuracy: 97.24%, Time: 139.49s, Gradient Norm: 14.9794
Validation Accuracy: 96.88%


                                                              

Epoch [11/20], Loss: 0.0866, Training Accuracy: 97.13%, Time: 162.88s, Gradient Norm: 13.6983
Validation Accuracy: 96.97%


                                                              

Epoch [12/20], Loss: 0.0842, Training Accuracy: 97.17%, Time: 144.72s, Gradient Norm: 13.7305
Validation Accuracy: 97.54%


                                                              

Epoch [13/20], Loss: 0.0843, Training Accuracy: 97.20%, Time: 146.03s, Gradient Norm: 13.4815
Validation Accuracy: 98.11%


                                                              

Epoch [14/20], Loss: 0.0815, Training Accuracy: 97.13%, Time: 172.56s, Gradient Norm: 13.0024
Validation Accuracy: 97.78%


                                                              

Epoch [15/20], Loss: 0.0786, Training Accuracy: 97.32%, Time: 174.06s, Gradient Norm: 12.2561
Validation Accuracy: 97.54%


                                                              

Epoch [16/20], Loss: 0.0791, Training Accuracy: 97.40%, Time: 154.18s, Gradient Norm: 12.4894
Validation Accuracy: 97.54%


                                                              

Epoch [17/20], Loss: 0.0750, Training Accuracy: 97.32%, Time: 141.86s, Gradient Norm: 11.3617
Validation Accuracy: 98.01%


                                                              

Epoch [18/20], Loss: 0.0720, Training Accuracy: 97.50%, Time: 142.93s, Gradient Norm: 10.7597
Validation Accuracy: 97.45%


                                                              

Epoch [19/20], Loss: 0.0785, Training Accuracy: 97.28%, Time: 143.96s, Gradient Norm: 11.4791
Validation Accuracy: 97.02%


                                                              

Epoch [20/20], Loss: 0.0727, Training Accuracy: 97.45%, Time: 143.26s, Gradient Norm: 10.2971
Validation Accuracy: 97.78%
Training metrics saved to results/training/UNet_20_epochs_1e-05_lr_class_imbalance_noisyno_pretraining.pkl
Training VAE


                                                                           

Epoch [1/20], Training Loss: 88447.8515, Reconstruction Loss: 88280.2917, KLD Loss: 3351.1963, Time: 59.45s, Gradient Norm: 2.2954
Epoch [1/20], Validation Loss: 82040.1045, Validation Recon: 81997.4111, Validation KLD: 853.8654


                                                                           

Epoch [2/20], Training Loss: 83508.5325, Reconstruction Loss: 83470.2077, KLD Loss: 766.4951, Time: 59.28s, Gradient Norm: 1.9511
Epoch [2/20], Validation Loss: 82630.6642, Validation Recon: 82603.6717, Validation KLD: 539.8466


                                                                           

Epoch [3/20], Training Loss: 83143.5742, Reconstruction Loss: 83120.0751, KLD Loss: 469.9808, Time: 59.49s, Gradient Norm: 1.8886
Epoch [3/20], Validation Loss: 82507.8525, Validation Recon: 82489.6630, Validation KLD: 363.7891


                                                                           

Epoch [4/20], Training Loss: 83246.8766, Reconstruction Loss: 83230.0515, KLD Loss: 336.5020, Time: 57.32s, Gradient Norm: 1.8397
Epoch [4/20], Validation Loss: 82296.7656, Validation Recon: 82282.8703, Validation KLD: 277.9059


                                                                           

Epoch [5/20], Training Loss: 83219.2714, Reconstruction Loss: 83205.9863, KLD Loss: 265.7024, Time: 62.05s, Gradient Norm: 1.8064
Epoch [5/20], Validation Loss: 82219.2155, Validation Recon: 82207.2971, Validation KLD: 238.3692


                                                                           

Epoch [6/20], Training Loss: 83230.3584, Reconstruction Loss: 83219.0094, KLD Loss: 226.9802, Time: 62.69s, Gradient Norm: 1.8036
Epoch [6/20], Validation Loss: 82979.5124, Validation Recon: 82968.7202, Validation KLD: 215.8383


                                                                           

Epoch [7/20], Training Loss: 83249.1310, Reconstruction Loss: 83238.5941, KLD Loss: 210.7374, Time: 65.51s, Gradient Norm: 1.8111
Epoch [7/20], Validation Loss: 83053.0210, Validation Recon: 83042.7934, Validation KLD: 204.5544


                                                                           

Epoch [8/20], Training Loss: 83085.7513, Reconstruction Loss: 83075.8898, KLD Loss: 197.2310, Time: 67.60s, Gradient Norm: 1.8245
Epoch [8/20], Validation Loss: 82599.7093, Validation Recon: 82590.0178, Validation KLD: 193.8332


                                                                           

Epoch [9/20], Training Loss: 83132.5000, Reconstruction Loss: 83122.9343, KLD Loss: 191.3132, Time: 68.06s, Gradient Norm: 1.8517
Epoch [9/20], Validation Loss: 83103.5459, Validation Recon: 83093.6173, Validation KLD: 198.5691


                                                                            

Epoch [10/20], Training Loss: 82995.4995, Reconstruction Loss: 82985.6224, KLD Loss: 197.5405, Time: 64.99s, Gradient Norm: 1.8737
Epoch [10/20], Validation Loss: 82519.8883, Validation Recon: 82508.5481, Validation KLD: 226.8040


                                                                            

Epoch [11/20], Training Loss: 83144.0912, Reconstruction Loss: 83131.7207, KLD Loss: 247.4078, Time: 69.06s, Gradient Norm: 1.8505
Epoch [11/20], Validation Loss: 82831.5099, Validation Recon: 82817.4918, Validation KLD: 280.3606


                                                                            

Epoch [12/20], Training Loss: 82875.2660, Reconstruction Loss: 82862.4270, KLD Loss: 256.7798, Time: 76.42s, Gradient Norm: 1.8107
Epoch [12/20], Validation Loss: 83138.7354, Validation Recon: 83125.6751, Validation KLD: 261.2104


                                                                            

Epoch [13/20], Training Loss: 82998.5128, Reconstruction Loss: 82986.4388, KLD Loss: 241.4774, Time: 60.31s, Gradient Norm: 1.7821
Epoch [13/20], Validation Loss: 82796.3623, Validation Recon: 82783.9105, Validation KLD: 249.0380


                                                                            

Epoch [14/20], Training Loss: 82937.6301, Reconstruction Loss: 82925.9759, KLD Loss: 233.0827, Time: 62.48s, Gradient Norm: 1.7793
Epoch [14/20], Validation Loss: 82532.7755, Validation Recon: 82521.0174, Validation KLD: 235.1644


                                                                            

Epoch [15/20], Training Loss: 83198.0118, Reconstruction Loss: 83186.4945, KLD Loss: 230.3454, Time: 57.86s, Gradient Norm: 1.7730
Epoch [15/20], Validation Loss: 82907.7215, Validation Recon: 82895.8467, Validation KLD: 237.4897


                                                                            

Epoch [16/20], Training Loss: 83195.3018, Reconstruction Loss: 83183.4754, KLD Loss: 236.5310, Time: 57.30s, Gradient Norm: 1.7902
Epoch [16/20], Validation Loss: 82519.0345, Validation Recon: 82506.2910, Validation KLD: 254.8758


                                                                            

Epoch [17/20], Training Loss: 83109.2972, Reconstruction Loss: 83096.5924, KLD Loss: 254.0959, Time: 58.82s, Gradient Norm: 1.7921
Epoch [17/20], Validation Loss: 81784.7388, Validation Recon: 81771.5207, Validation KLD: 264.3620


                                                                            

Epoch [18/20], Training Loss: 83160.0296, Reconstruction Loss: 83146.5670, KLD Loss: 269.2522, Time: 58.84s, Gradient Norm: 1.7867
Epoch [18/20], Validation Loss: 81857.0595, Validation Recon: 81842.4823, Validation KLD: 291.5483


                                                                            

Epoch [19/20], Training Loss: 82840.0847, Reconstruction Loss: 82825.5383, KLD Loss: 290.9265, Time: 67.55s, Gradient Norm: 1.7897
Epoch [19/20], Validation Loss: 82391.1970, Validation Recon: 82375.5384, Validation KLD: 313.1713


                                                                            

Epoch [20/20], Training Loss: 82841.2742, Reconstruction Loss: 82825.4222, KLD Loss: 317.0382, Time: 60.95s, Gradient Norm: 1.7720
Epoch [20/20], Validation Loss: 82487.4265, Validation Recon: 82470.5245, Validation KLD: 338.0410
Training metrics saved to results/training/VAE_20_epochs_1e-05_lr_class_imbalance_noisy_generation.pkl


                                                             

Epoch [1/20], Loss: 0.7858, Training Accuracy: 45.30%, Time: 62.76s, Gradient Norm: 2.7854
Validation Accuracy: 43.12%


                                                             

Epoch [2/20], Loss: 0.7676, Training Accuracy: 44.77%, Time: 62.19s, Gradient Norm: 2.7669
Validation Accuracy: 43.12%


                                                             

Epoch [3/20], Loss: 0.7581, Training Accuracy: 44.39%, Time: 63.02s, Gradient Norm: 2.7367
Validation Accuracy: 42.13%


                                                             

Epoch [4/20], Loss: 0.7470, Training Accuracy: 44.97%, Time: 63.96s, Gradient Norm: 2.7325
Validation Accuracy: 42.60%


                                                             

Epoch [5/20], Loss: 0.7299, Training Accuracy: 44.96%, Time: 60.17s, Gradient Norm: 2.6906
Validation Accuracy: 44.02%


                                                             

Epoch [6/20], Loss: 0.7268, Training Accuracy: 45.54%, Time: 58.96s, Gradient Norm: 2.6624
Validation Accuracy: 42.27%


                                                             

Epoch [7/20], Loss: 0.7144, Training Accuracy: 44.71%, Time: 57.48s, Gradient Norm: 2.6680
Validation Accuracy: 41.18%


                                                             

Epoch [8/20], Loss: 0.7032, Training Accuracy: 45.38%, Time: 56.56s, Gradient Norm: 2.5982
Validation Accuracy: 41.32%


                                                             

Epoch [9/20], Loss: 0.6970, Training Accuracy: 46.51%, Time: 55.51s, Gradient Norm: 2.5815
Validation Accuracy: 40.80%


                                                              

Epoch [10/20], Loss: 0.6920, Training Accuracy: 45.91%, Time: 55.72s, Gradient Norm: 2.5583
Validation Accuracy: 42.74%


                                                              

Epoch [11/20], Loss: 0.6865, Training Accuracy: 46.44%, Time: 55.67s, Gradient Norm: 2.5236
Validation Accuracy: 42.88%


                                                              

Epoch [12/20], Loss: 0.6784, Training Accuracy: 46.81%, Time: 57.06s, Gradient Norm: 2.5166
Validation Accuracy: 43.12%


                                                              

Epoch [13/20], Loss: 0.6679, Training Accuracy: 46.33%, Time: 56.47s, Gradient Norm: 2.4744
Validation Accuracy: 43.45%


                                                              

Epoch [14/20], Loss: 0.6666, Training Accuracy: 46.40%, Time: 55.50s, Gradient Norm: 2.4770
Validation Accuracy: 44.16%


                                                              

Epoch [15/20], Loss: 0.6625, Training Accuracy: 47.17%, Time: 55.99s, Gradient Norm: 2.4718
Validation Accuracy: 42.93%


                                                              

Epoch [16/20], Loss: 0.6622, Training Accuracy: 46.62%, Time: 54.13s, Gradient Norm: 2.4322
Validation Accuracy: 44.02%


                                                              

Epoch [17/20], Loss: 0.6577, Training Accuracy: 47.84%, Time: 57.51s, Gradient Norm: 2.4355
Validation Accuracy: 43.50%


                                                              

Epoch [18/20], Loss: 0.6484, Training Accuracy: 47.82%, Time: 60.72s, Gradient Norm: 2.4378
Validation Accuracy: 44.54%


                                                              

Epoch [19/20], Loss: 0.6470, Training Accuracy: 47.63%, Time: 58.83s, Gradient Norm: 2.4177
Validation Accuracy: 43.74%


                                                              

Epoch [20/20], Loss: 0.6448, Training Accuracy: 47.73%, Time: 59.89s, Gradient Norm: 2.4112
Validation Accuracy: 42.36%
Training metrics saved to results/training/VAE_20_epochs_1e-05_lr_class_imbalance_noisy.pkl
0.45673758865248226
Training metrics saved to results/training/VAE_class_imbalance_noisy_kmeans.pkl
Kmeans saved to results/models/VAE_class_imbalance_noisy_kmeans.pkl


In [15]:
# Strong class imbalance

image_size = 64
batch_size = 16

class AddGaussianNoise(object):
    def __init__(self, mean=0., std=1.):
        self.std = std
        self.mean = mean
        
    def __call__(self, tensor):
        return tensor + torch.randn(tensor.size()) * self.std + self.mean
    
    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
    
custom_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),  # 50% chance of horizontal flip
    transforms.RandomVerticalFlip(p=0.5),    # 50% chance of vertical flip
    transforms.RandomAffine(
        degrees=30,                          # Random rotation within ±30 degrees
        translate=(0.1, 0.1),                # Random translation up to 10% of image size
        scale=(0.9, 1.1),                    # Random scaling between 90% and 110%
        shear=10                             # Random shear within ±10 degrees
    ),
    transforms.RandomPerspective(
        distortion_scale=0.5,                # Distortion scale for perspective transform
        p=0.5                                # 50% chance of applying
    ),
    transforms.RandomErasing(
        p=0.5,                               # 50% chance of applying
        scale=(0.02, 0.2),                   # Proportion of erased area
        ratio=(0.3, 3.3),                    # Aspect ratio of erased area
        value=0                              # Fill value for erased pixels
    ),
    transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 2.0)),  # Add Gaussian Blur
    transforms.RandomAdjustSharpness(sharpness_factor=2, p=0.5),  # Adjust sharpness
    AddGaussianNoise(),
])

learning_rates = [1e-3, 1e-5]
description = "class_imbalance"

## Get binary MNIST subsets
train_subset, test_subset = get_mnist()

## Get DataLoaders with additional transformations
train_loader, test_loader = get_dataloader(train_subset, test_subset, transform=None, size=image_size, batch_size=batch_size, class_imbalance=0.9)

## Train
for i, lr in enumerate(learning_rates):
    print("Training VGG")
    vgg = VGG(input_img_size=image_size, input_img_c=1)
    train_binary_classifier(vgg, train_loader, test_loader, epochs=epochs, learning_rate=lr, device="cuda", save_results=True, description=description)

    if i == 0:
        torch.save(vgg.state_dict(), f"./results/models/vgg_{epochs}_epochs_{lr}_lr_{description}.pth.tar")

    del_model(vgg)

    print("Training Resnet")
    resnet = ResNet(input_img_size=image_size, input_img_c=1)
    train_binary_classifier(resnet, train_loader, test_loader, epochs=epochs, learning_rate=lr, device="cuda", save_results=True, description=description)

    if i == 0:
        resnet_path = f"./results/models/resnet_{epochs}_epochs_{lr}_lr_{description}.pth.tar"
        torch.save(resnet.state_dict(), resnet_path)

    del_model(resnet)

    print("Training Unet")
    for j in range(2):
        if j == 0:
            resnet = ResNet(input_img_size=image_size, input_img_c=1)  
            unet_description = description + "_pretrained_resnet"
            resnet.load_state_dict(torch.load(resnet_path))
            resnet.eval()
            resnet.to("cuda")
            unet = UNet(image_size, resnet)
        else:
            unet_description = description + "no_pretraining"
            unet = UNet(image_size)

        train_binary_classifier(unet, train_loader, test_loader, epochs=epochs, learning_rate=lr, device="cuda", save_results=True, description=unet_description)

        if i == 0:
            torch.save(unet.state_dict(), f"./results/models/unet_{epochs}_epochs_{lr}_lr_{unet_description}.pth.tar")

        del_model(unet)

    print("Training VAE")
    # for j in range(2):
    vae = VAE(image_size**2, 32*32, 8, beta=0.05)
        

    train_vae(vae, train_loader, test_loader, epochs=epochs, learning_rate=lr, device="cuda", save_results=True, description=description)

    if i == 0:
        torch.save(vae.state_dict(), f"./results/models/vae_{epochs}_epochs_{lr}_lr_{description}_generation.pth.tar")

    vae.classification_mode = "SVM"

    # Train VAE classification head
    criterion = SVMLoss()
    for param in vae.parameters():
        param.requires_grad = False
    for name, param in vae.named_parameters():
        if "svm_layer" in name:
            param.requires_grad = True
    train_binary_classifier(vae, train_loader, test_loader, criterion=criterion, epochs=epochs, learning_rate=lr, device="cuda", save_results=True, description=description)

    if i == 0:
        torch.save(vae.state_dict(), f"./results/models/vae_{epochs}_epochs_{lr}_lr_{description}_classification.pth.tar")

    train_vae_kmeans(vae, train_loader, test_loader, "cuda", True, description)


    del_model(vae)


Training VGG


                                                             

Epoch [1/20], Loss: 0.3953, Training Accuracy: 97.98%, Time: 79.65s, Gradient Norm: 12.6666
Validation Accuracy: 81.09%


                                                             

Epoch [2/20], Loss: 0.1061, Training Accuracy: 98.07%, Time: 79.60s, Gradient Norm: 3.9091
Validation Accuracy: 99.57%


                                                             

Epoch [3/20], Loss: 0.0545, Training Accuracy: 99.01%, Time: 79.50s, Gradient Norm: 6.7300
Validation Accuracy: 99.10%


                                                             

Epoch [4/20], Loss: 0.0300, Training Accuracy: 99.02%, Time: 78.98s, Gradient Norm: 0.9624
Validation Accuracy: 99.72%


                                                             

Epoch [5/20], Loss: 0.0175, Training Accuracy: 99.62%, Time: 79.10s, Gradient Norm: 0.5971
Validation Accuracy: 99.62%


                                                             

Epoch [6/20], Loss: 0.0101, Training Accuracy: 99.76%, Time: 79.09s, Gradient Norm: 0.3703
Validation Accuracy: 99.86%


                                                             

Epoch [7/20], Loss: 0.0072, Training Accuracy: 99.85%, Time: 79.39s, Gradient Norm: 0.2393
Validation Accuracy: 99.95%


                                                             

Epoch [8/20], Loss: 0.0018, Training Accuracy: 99.96%, Time: 79.48s, Gradient Norm: 0.0824
Validation Accuracy: 99.86%


                                                             

Epoch [9/20], Loss: 0.0001, Training Accuracy: 100.00%, Time: 79.69s, Gradient Norm: 0.0119
Validation Accuracy: 99.39%


                                                              

Epoch [10/20], Loss: 0.0150, Training Accuracy: 99.68%, Time: 79.76s, Gradient Norm: 0.5108
Validation Accuracy: 99.43%


                                                              

Epoch [11/20], Loss: 0.0510, Training Accuracy: 99.62%, Time: 79.52s, Gradient Norm: 2.7662
Validation Accuracy: 99.57%


                                                              

Epoch [12/20], Loss: 0.0085, Training Accuracy: 99.79%, Time: 79.54s, Gradient Norm: 0.2983
Validation Accuracy: 99.91%


                                                              

Epoch [13/20], Loss: 0.0041, Training Accuracy: 99.88%, Time: 79.04s, Gradient Norm: 0.1713
Validation Accuracy: 99.95%


                                                              

Epoch [14/20], Loss: 0.0022, Training Accuracy: 99.93%, Time: 65.56s, Gradient Norm: 0.0903
Validation Accuracy: 99.76%


                                                              

Epoch [15/20], Loss: 0.0019, Training Accuracy: 99.97%, Time: 59.35s, Gradient Norm: 0.0620
Validation Accuracy: 99.91%


                                                              

Epoch [16/20], Loss: 0.0000, Training Accuracy: 100.00%, Time: 59.06s, Gradient Norm: 0.0020
Validation Accuracy: 99.86%


                                                              

Epoch [17/20], Loss: 0.0000, Training Accuracy: 100.00%, Time: 58.67s, Gradient Norm: 0.0008
Validation Accuracy: 99.91%


                                                              

Epoch [18/20], Loss: 0.0000, Training Accuracy: 100.00%, Time: 58.76s, Gradient Norm: 0.0002
Validation Accuracy: 99.86%


                                                              

Epoch [19/20], Loss: 0.0000, Training Accuracy: 100.00%, Time: 58.69s, Gradient Norm: 0.0000
Validation Accuracy: 99.91%


                                                              

Epoch [20/20], Loss: 0.0000, Training Accuracy: 100.00%, Time: 59.06s, Gradient Norm: 0.0000
Validation Accuracy: 99.91%
Training metrics saved to results/training/VGG_20_epochs_0.001_lr_class_imbalance_strong.pkl
Training Resnet


                                                             

Epoch [1/20], Loss: 0.4659, Training Accuracy: 98.97%, Time: 66.15s, Gradient Norm: 16.2530
Validation Accuracy: 99.48%


                                                             

Epoch [2/20], Loss: 1.4657, Training Accuracy: 99.57%, Time: 64.96s, Gradient Norm: 24.7864
Validation Accuracy: 99.86%


                                                             

Epoch [3/20], Loss: 0.5177, Training Accuracy: 99.83%, Time: 64.62s, Gradient Norm: 7.0313
Validation Accuracy: 99.86%


                                                             

Epoch [4/20], Loss: 0.1313, Training Accuracy: 99.91%, Time: 65.38s, Gradient Norm: 2.7246
Validation Accuracy: 99.95%


                                                             

Epoch [5/20], Loss: 0.4842, Training Accuracy: 99.77%, Time: 65.32s, Gradient Norm: 11.0824
Validation Accuracy: 99.81%


                                                             

Epoch [6/20], Loss: 0.4004, Training Accuracy: 99.84%, Time: 65.37s, Gradient Norm: 7.9494
Validation Accuracy: 99.76%


                                                             

Epoch [7/20], Loss: 0.0971, Training Accuracy: 99.93%, Time: 72.71s, Gradient Norm: 1.6994
Validation Accuracy: 99.91%


                                                             

Epoch [8/20], Loss: 0.1252, Training Accuracy: 99.89%, Time: 72.09s, Gradient Norm: 3.5569
Validation Accuracy: 99.86%


                                                             

Epoch [9/20], Loss: 0.0926, Training Accuracy: 99.94%, Time: 70.90s, Gradient Norm: 2.1427
Validation Accuracy: 99.95%


                                                              

Epoch [10/20], Loss: 0.2543, Training Accuracy: 99.91%, Time: 71.71s, Gradient Norm: 5.8470
Validation Accuracy: 99.72%


                                                              

Epoch [11/20], Loss: 0.4036, Training Accuracy: 99.88%, Time: 73.06s, Gradient Norm: 7.0085
Validation Accuracy: 99.72%


                                                              

Epoch [12/20], Loss: 0.0076, Training Accuracy: 99.98%, Time: 66.59s, Gradient Norm: 1.4614
Validation Accuracy: 99.81%


                                                              

Epoch [13/20], Loss: 0.0478, Training Accuracy: 99.98%, Time: 70.53s, Gradient Norm: 0.6704
Validation Accuracy: 99.86%


                                                              

Epoch [14/20], Loss: 0.3204, Training Accuracy: 99.91%, Time: 68.44s, Gradient Norm: 5.7452
Validation Accuracy: 99.91%


                                                              

Epoch [15/20], Loss: 0.1240, Training Accuracy: 99.94%, Time: 69.08s, Gradient Norm: 2.1489
Validation Accuracy: 100.00%


                                                              

Epoch [16/20], Loss: 0.0183, Training Accuracy: 99.97%, Time: 67.62s, Gradient Norm: 0.9534
Validation Accuracy: 99.91%


                                                              

Epoch [17/20], Loss: 0.0000, Training Accuracy: 100.00%, Time: 67.69s, Gradient Norm: 0.0000
Validation Accuracy: 99.86%


                                                              

Epoch [18/20], Loss: 0.1425, Training Accuracy: 99.95%, Time: 67.63s, Gradient Norm: 2.4028
Validation Accuracy: 99.95%


                                                              

KeyboardInterrupt: 