In [None]:
import warnings
warnings.filterwarnings("ignore")

import pandas as pd
import numpy as np
import os, pickle, shutil, random, PIL
from PIL import Image

import torch
import torch.nn as nn
import torchvision
import torch.optim as optim
from torchvision import models
from torch.utils.data import DataLoader,random_split,Dataset, ConcatDataset ,SubsetRandomSampler
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from torchvision.transforms import v2

import matplotlib.pyplot as plt
from focal_loss_with_smoothing import FocalLossWithSmoothing
from model import *
# from torchinfo import summary
from training_utils import *

In [None]:
# Seed every RNG the pipeline touches (Python, NumPy, PyTorch) so runs are repeatable.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

# Reproducibility over speed: cuDNN autotuning (benchmark=True) can select a different
# convolution algorithm from run to run, which works against deterministic=True. PyTorch's
# reproducibility notes recommend benchmark=False alongside deterministic=True.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# File stem of the training-history pickle written to ./storage/ by the training cell.
DFNAME = '_'
# Use the first GPU when present; fall back to CPU instead of failing at the first .to(device).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Two-class focal loss with label smoothing (project module focal_loss_with_smoothing).
criterion1 = FocalLossWithSmoothing(num_classes =2,gamma=2, lb_smooth = 0.1)
# MSE term passed alongside criterion1 to train_both/test_both — presumably the
# autoencoder reconstruction loss; confirm in training_utils.
criterion2 = nn.MSELoss()

# Run name; reassigned by the checkpoint cells before it is read.
modelname = 'rn50_autoencoder_withlatent'
n_epochs = 100
batch_size = 4

In [None]:
# Root folder holding the datasets. Override it with the MASS_DATA_ROOT environment
# variable instead of editing hard-coded absolute paths.
DATA_ROOT = os.environ.get('MASS_DATA_ROOT', '/nfs/cc-filer/home/mpervin/Saud')
# Dataset sub-folder. The alternative previously used here is 'mias_mass_non_mass'.
DATASET = 'mass_non_mass_1024'

# ImageFolder layout: one sub-directory per class under train/ and test/.
train_dir = os.path.join(DATA_ROOT, DATASET, 'train') + os.sep
test_dir = os.path.join(DATA_ROOT, DATASET, 'test') + os.sep

In [None]:
# Network input resolution (height, width).
size = (512,512)

# Training pipeline: resize, single-channel grayscale, then random flips/rotation.
# NOTE: v2.ToTensor is deprecated in recent torchvision; the documented replacement is
# v2.ToImage() + v2.ToDtype(torch.float32, scale=True). It is kept here for compatibility.
train_transform = transforms.Compose([
    v2.Resize(size),
    v2.Grayscale(1),
    v2.RandomHorizontalFlip(0.5),
    v2.RandomVerticalFlip(0.5),
    v2.RandomRotation(30),
    v2.ToTensor(),
])

# Deterministic pipeline (no augmentation) for validation and test.
eval_transform = transforms.Compose([
    v2.Resize(size),
    v2.Grayscale(1),
    v2.ToTensor(),
])

train_set_whole = ImageFolder(train_dir, transform=train_transform)
# Same files as train_set_whole, read with the deterministic pipeline, so the validation
# subset is not randomly augmented.
train_set_whole_eval = ImageFolder(train_dir, transform=eval_transform)
test_set = ImageFolder(test_dir, transform=eval_transform)

# 90/10 train/validation split. The validation length is the remainder, so the lengths
# always sum to the dataset size; random_split raises ValueError otherwise. The previous
# int(0.9n) / int(0.1n)+1 lengths summed to n+1 whenever n was a multiple of 10.
n_total = len(train_set_whole)
n_train = int(n_total * 0.9)
n_valid = n_total - n_train
# Split indices once with a seeded generator, then view them through the augmented (train)
# and deterministic (valid) datasets.
train_split, valid_split = random_split(range(n_total), [n_train, n_valid],
                                        generator=torch.Generator().manual_seed(0))
train_set = torch.utils.data.Subset(train_set_whole, train_split.indices)
valid_set = torch.utils.data.Subset(train_set_whole_eval, valid_split.indices)

In [None]:
# Mini-batch loaders. Only the training loader shuffles. Evaluation uses a fixed order, so
# validation metrics are reproducible, and a shuffling sampler would also draw from torch's
# global RNG every epoch.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4)
valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False, num_workers=4)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=4)
# One-image-at-a-time loader over the test set (not used by the cells shown here).
test_loader_2 = DataLoader(test_set, batch_size=1, shuffle=False)

In [None]:
### Build the model. Only `AutoencoderWithClassification` is instantiated below; to try another architecture from the `model` module, replace that line rather than running several model cells.

In [None]:
# Two-class AutoencoderWithClassification (project `model` module), moved to the configured device.
model = AutoencoderWithClassification(num_classes=2)
model = model.to(device)

In [None]:
# Warm-start from an earlier run's checkpoint: a dict whose 'model' entry is a state_dict.
# Tensors are mapped to CPU on load; load_state_dict copies them into the model's existing
# parameters. torch.load unpickles the file, so only load checkpoints you trust.
modelname = 'MvNM_rn50_autoencoder_gendata'
warm_start_path = './checkpoint/' + modelname + 'model.pth.tar'
checkpoint = torch.load(warm_start_path, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model'])

# Run name for this fine-tuning session (passed to test_both and used to reload the checkpoint).
modelname = 'proposed'

In [None]:
# summary(model, input_size=(1,1,768,512), col_names =['input_size', 'output_size','num_params','trainable'] )  

In [None]:
# Optional partial freezing for fine-tuning. named_parameters() yields individual
# parameter tensors (weights and biases separately), so num_layers_to_freeze counts
# tensors, not layers. Off by default: every parameter stays trainable, which is what the
# previous loop did, since its freeze statement was commented out.
freeze_layers = False
num_layers_to_freeze = 33

if freeze_layers:
    for idx, (name, param) in enumerate(model.named_parameters()):
        if idx >= num_layers_to_freeze:
            break
        param.requires_grad = False
        print(f"Freezing layer {name}")

    n_frozen = sum(not p.requires_grad for p in model.parameters())
    print(f"{n_frozen} parameter tensors frozen")

In [None]:
# Adam over all model parameters with L2 weight decay.
# NOTE(review): this rebinds `optim`, shadowing the `torch.optim` module alias from the
# import cell; the training cells pass this optimizer instance under that name.
adam_lr, adam_weight_decay = 0.0005, 1e-4
optim = torch.optim.Adam(model.parameters(), lr=adam_lr, weight_decay=adam_weight_decay)

In [None]:
# Halve the learning rate when the validation loss has not improved for 3 epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, mode='min', factor=0.5, patience=3)

history = {'train_loss': [], 'valid_loss': [],'train_acc':[],'valid_acc':[]}
history_path = os.path.join('./storage', DFNAME + '.pkl')
# Create the output folder up front: a missing ./storage used to make the final save fail
# only after every epoch had already run.
os.makedirs(os.path.dirname(history_path), exist_ok=True)

for epoch in range(n_epochs):
    # train_both / test_both come from training_utils and take both criteria. test_both also
    # receives modelname — presumably to checkpoint the best weights; confirm there.
    train_loss, train_acc = train_both(model,train_loader,criterion1,criterion2,optim,device,epoch)
    valid_loss, valid_acc = test_both(model,valid_loader,criterion1,criterion2,optim,modelname,device,epoch)

    scheduler.step(valid_loss)

    history['train_loss'].append(train_loss)
    history['valid_loss'].append(valid_loss)
    history['train_acc'].append(train_acc)
    history['valid_acc'].append(valid_acc)

    # Persist after every epoch so an interrupted run keeps the curves recorded so far.
    with open(history_path, 'wb') as f:
        pickle.dump(history, f)

In [None]:
### For testing, use the same architecture that was trained: create a fresh instance of it
### and load the best-performing weights that were saved during training.

In [None]:
# Fresh instance of the trained architecture; weights are loaded in the next cell.
new_model = AutoencoderWithClassification(num_classes=2)
new_model = new_model.to(device)

In [None]:
# Reload the checkpoint stored for the current `modelname` (written during training,
# presumably by test_both — confirm in training_utils). torch.load unpickles the file,
# so only load trusted checkpoints.
best_checkpoint_path = './checkpoint/' + modelname + 'model.pth.tar'
checkpoint = torch.load(best_checkpoint_path, map_location=torch.device('cpu'))
new_model.load_state_dict(checkpoint['model'])

# Evaluate the reloaded weights on the test set.
best_test_both(new_model, test_loader, criterion1, criterion2, device)

In [None]:
# Evaluate the reloaded weights on the training subset. train_loader shuffles and applies
# the random flip/rotation augmentation, so these numbers are on augmented images.
best_test_both(new_model, train_loader, criterion1, criterion2, device)

In [None]:
# Evaluate the reloaded weights on the 10% validation split held out from train_dir.
best_test_both(new_model, valid_loader, criterion1, criterion2, device)