# Autoencoder + Classification on encoded data

In [None]:
import os
import re
import shutil
from skimage import io
from PIL import Image

from torchvision import transforms
from torch.utils.data import DataLoader
from skimage import io
import torch
import glob
import torch
from torch import nn
from torchvision import models
from torch.optim import Adam
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix

import seaborn as sns
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, confusion_matrix
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, Dataset

import torchvision
from torchvision import datasets, models, transforms

from itertools import cycle

from sklearn import metrics
from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn.metrics import roc_curve
from sklearn.metrics import RocCurveDisplay
from sklearn.metrics import auc
import copy
from torchvision.models import ResNet50_Weights

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

import glob
# import staintools
from PIL import ImageOps

## Prepare the label

In [None]:
# Column holding the target (Ki67 index, in percent) and the raw spreadsheet.
label = 'Ki67 (%)'
df = pd.read_excel(
    'data_file.xlsx',
)

In [None]:
# Keep only the patient ID and the target, dropping rows without a target value.
df = df[['ID_number', label]].dropna(subset=[label])

In [None]:
np.unique(df[label].to_numpy(), return_counts=True)

In [None]:
# Percentile cut-offs of the target distribution.
q20, q80 = df[label].quantile(0.20), df[label].quantile(0.80)
print(q20)
print(q80)

# Keep only the tails: rows at or below the 20th percentile, or at or above the 80th.
in_tails = (df[label] <= q20) | (df[label] >= q80)
df_extreme = df[in_tails].copy()

# Binarise in place: low tail -> 0 first, then high tail -> 1.
df_extreme.loc[df_extreme[label] <= q20, label] = 0
df_extreme.loc[df_extreme[label] >= q80, label] = 1

In [None]:
q20, q80 = df[label].quantile(0.20), df[label].quantile(0.80)
q30, q70 = df[label].quantile(0.30), df[label].quantile(0.70)

# "Out-of-distribution" rows: the bands just inside the extreme tails,
# (q20, q30] on the low side and [q70, q80) on the high side.
low_band = (df[label] > q20) & (df[label] <= q30)
high_band = (df[label] >= q70) & (df[label] < q80)
df_ood = df[low_band | high_band].copy()

# Binarise in place: low band -> 0 first, then high band -> 1.
df_ood.loc[df_ood[label] <= q30, label] = 0
df_ood.loc[df_ood[label] >= q70, label] = 1

In [None]:
np.unique(df_extreme[label].to_numpy(), return_counts=True)

# Create the different datasets

## Create the original dataset with all the images

In [None]:
class BiopsyDataset(torch.utils.data.Dataset):
    """Labelled biopsy images whose IDs appear in a dataframe.

    Scans ``root_dir`` for ``*.jpeg`` files, takes each file's ID as the part of
    its basename before ``'.tif'``, and keeps only files whose ID occurs in
    ``df['ID_number']``. Items are ``(image, label)`` pairs where ``image`` is the
    array returned by ``io.imread`` (optionally passed through ``transform``).

    Args:
        root_dir: directory containing the ``.jpeg`` images.
        df: dataframe with an ``ID_number`` column and a label column.
        transform: optional callable applied to each image in ``__getitem__``.
        label_col: name of the label column. ``None`` (default) falls back to the
            notebook-level ``label`` global, as the original implementation did.
    """

    def __init__(self, root_dir, df, transform=None, label_col=None):
        self.root_dir = root_dir
        self.transform = transform
        self.df = df
        if label_col is None:
            label_col = label  # notebook-level global, kept for backward compatibility
        self.image_filenames = []
        self.labels = []

        # Build the ID -> label lookup once instead of scanning/filtering df per file.
        # setdefault keeps the first row for a duplicated ID, like `.values[0]` did.
        id_to_label = {}
        for id_number, value in zip(df['ID_number'].tolist(), df[label_col].to_numpy()):
            id_to_label.setdefault(id_number, value)

        # Sorted so the dataset order (and hence any seeded subsampling built on
        # these indices) does not depend on the filesystem's listing order.
        for image_filename in sorted(glob.glob(os.path.join(self.root_dir, '*.jpeg'))):
            id_number = os.path.basename(image_filename).split('.tif')[0]
            if id_number in id_to_label:
                self.image_filenames.append(image_filename)
                self.labels.append(id_to_label[id_number])

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        image = io.imread(self.image_filenames[idx])
        if self.transform:
            image = self.transform(image)
        target = self.labels[idx]
        return image, target

    def get_label_indices(self):
        """Return ``(indices with label 0, indices with label 1)``."""
        zeros_indices = [i for i, x in enumerate(self.labels) if x == 0]
        ones_indices = [i for i, x in enumerate(self.labels) if x == 1]
        return zeros_indices, ones_indices
    

# Deterministic preprocessing: image array -> PIL -> shorter side 128 px -> float tensor.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(128),
    transforms.ToTensor(),
])

# 70 % train; the remaining 30 % is split ~2:1 into validation and test,
# stratified on the binary label and seeded for reproducibility.
train_df, temp_df = train_test_split(
    df_extreme, test_size=0.3, random_state=41, stratify=df_extreme[label]
)
valid_df, test_df = train_test_split(
    temp_df, test_size=0.33, random_state=41, stratify=temp_df[label]
)

# One dataset per split, plus the out-of-distribution band set, all from the same image folder.
image_root = "/storage/Chloe/zoom_20_512"
train_dataset, valid_dataset, test_dataset, test_dataset_ood = (
    BiopsyDataset(image_root, frame, transform=transform)
    for frame in (train_df, valid_df, test_df, df_ood)
)

In [None]:
print(*(len(ds) for ds in (train_dataset, valid_dataset, test_dataset, test_dataset_ood)), sep='\n')

## Create class-balanced labeled datasets

In [None]:
class SubsetDataset(torch.utils.data.Dataset):
    """Read-only view of ``original_dataset`` restricted to ``subset_indices``.

    Position ``k`` of the view maps to ``original_dataset[subset_indices[k]]``.
    """

    def __init__(self, subset_indices, original_dataset):
        self.subset_indices = subset_indices
        self.original_dataset = original_dataset

    def __len__(self):
        return len(self.subset_indices)

    def __getitem__(self, index):
        return self.original_dataset[self.subset_indices[index]]

In [None]:
def balanced_data(dataset, seed=42):
    """Undersample the majority class so both binary classes have the same size.

    Args:
        dataset: object exposing ``get_label_indices()``, which returns the
            indices of label-0 and label-1 samples.
        seed: seed for the random undersampling. The default of 42 reproduces the
            draws made when this function seeded the global NumPy RNG.

    Returns:
        A ``SubsetDataset`` over ``dataset`` with ``min(n0, n1)`` randomly chosen
        class-0 indices followed by the same number of class-1 indices.
    """
    zero_index, one_index = dataset.get_label_indices()
    min_class_size = min(len(zero_index), len(one_index))
    # A private legacy RandomState seeded with `seed` produces the same draws as
    # np.random.seed(seed) + np.random.choice, without resetting NumPy's global
    # RNG as a side effect for everything else in the notebook.
    rng = np.random.RandomState(seed)
    balanced_indices_class0 = rng.choice(zero_index, size=min_class_size, replace=False)
    balanced_indices_class1 = rng.choice(one_index, size=min_class_size, replace=False)
    balanced_indices = np.concatenate([balanced_indices_class0, balanced_indices_class1])
    return SubsetDataset(balanced_indices, dataset)

In [None]:
# Undersample each split to equal class sizes.
balanced_train_dataset, balanced_valid_dataset, balanced_test_dataset, balanced_test_ood_dataset = map(
    balanced_data,
    (train_dataset, valid_dataset, test_dataset, test_dataset_ood),
)

In [None]:
print(
    *(len(ds) for ds in (balanced_train_dataset, balanced_valid_dataset,
                         balanced_test_dataset, balanced_test_ood_dataset)),
    sep='\n',
)

## Create unlabeled datasets

In [None]:
class UnlabeledDataset(Dataset):
    """Expose only the images of a dataset that yields ``(image, label)`` pairs."""

    def __init__(self, labeled_dataset):
        self.labeled_dataset = labeled_dataset

    def __getitem__(self, idx):
        sample, _label = self.labeled_dataset[idx]  # the label is discarded
        return sample

    def __len__(self):
        return len(self.labeled_dataset)

In [None]:
# Autoencoder training only needs the images, so the labels are stripped.
train_unlabeled_dataset = UnlabeledDataset(balanced_train_dataset)
valid_unlabeled_dataset = UnlabeledDataset(balanced_valid_dataset)

train_unlabeled_dataloader = DataLoader(train_unlabeled_dataset, batch_size=16, shuffle=True)
valid_unlabeled_dataloader = DataLoader(valid_unlabeled_dataset, batch_size=16, shuffle=True)

# Autoencoders

In [None]:
# All candidate autoencoders are defined below. Every cell defines the same class name
# `Autoencoder`, so comment out all of them except the one whose latent dimension you want to analyze.

In [None]:
# Latent dimension 16,384 (64 x 16 x 16)
class Autoencoder(nn.Module):
    """Convolutional autoencoder for 128 x 128 RGB input with a 64 x 16 x 16 latent map.

    The encoder applies three stride-2 3x3 convolutions (ReLU between them, none
    after the last), halving the spatial size each time. The decoder mirrors the
    channel widths with stride-2 transposed convolutions and ends in a sigmoid.
    """

    # Channel widths from the input image to the latent map; the decoder runs them backwards.
    _CHANNELS = (3, 16, 32, 64)

    def __init__(self):
        super().__init__()
        widths = self._CHANNELS

        down = []
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            if stage:
                down.append(nn.ReLU())
            down.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))
        self.encoder = nn.Sequential(*down)

        mirrored = widths[::-1]
        up = []
        for c_in, c_out in zip(mirrored[:-1], mirrored[1:]):
            up.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2,
                                         padding=1, output_padding=1))
            up.append(nn.ReLU())
        up[-1] = nn.Sigmoid()  # final activation is a sigmoid rather than ReLU
        self.decoder = nn.Sequential(*up)

    def forward(self, x):
        latent = self.encoder(x)  # [batch, 64, 16, 16] for a 128 x 128 input
        return self.decoder(latent)


In [None]:
# Latent dimension 4096 (256 x 4 x 4):
class Autoencoder(nn.Module):
    """Convolutional autoencoder for 128 x 128 RGB input with a 256 x 4 x 4 latent map.

    The encoder applies five stride-2 3x3 convolutions (ReLU between them, none
    after the last). The decoder mirrors the channel widths with stride-2
    transposed convolutions and ends in a sigmoid.
    """

    # Channel widths from the input image to the latent map; the decoder runs them backwards.
    _CHANNELS = (3, 16, 32, 64, 128, 256)

    def __init__(self):
        super().__init__()
        widths = self._CHANNELS

        down = []
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            if stage:
                down.append(nn.ReLU())
            down.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))
        self.encoder = nn.Sequential(*down)

        mirrored = widths[::-1]
        up = []
        for c_in, c_out in zip(mirrored[:-1], mirrored[1:]):
            up.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2,
                                         padding=1, output_padding=1))
            up.append(nn.ReLU())
        up[-1] = nn.Sigmoid()  # final activation is a sigmoid rather than ReLU
        self.decoder = nn.Sequential(*up)

    def forward(self, x):
        latent = self.encoder(x)  # [batch, 256, 4, 4] for a 128 x 128 input
        return self.decoder(latent)

In [None]:
# Latent dimension 512 (128 x 2 x 2)
class Autoencoder(nn.Module):
    """Convolutional autoencoder for 128 x 128 RGB input with a 128 x 2 x 2 latent map.

    The encoder applies six stride-2 3x3 convolutions (ReLU between them, none
    after the last). The decoder mirrors the channel widths with stride-2
    transposed convolutions and ends in a sigmoid.
    """

    # Channel widths from the input image to the latent map; the decoder runs them backwards.
    _CHANNELS = (3, 4, 8, 16, 32, 64, 128)

    def __init__(self):
        super().__init__()
        widths = self._CHANNELS

        down = []
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            if stage:
                down.append(nn.ReLU())
            down.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))
        self.encoder = nn.Sequential(*down)

        mirrored = widths[::-1]
        up = []
        for c_in, c_out in zip(mirrored[:-1], mirrored[1:]):
            up.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2,
                                         padding=1, output_padding=1))
            up.append(nn.ReLU())
        up[-1] = nn.Sigmoid()  # final activation is a sigmoid rather than ReLU
        self.decoder = nn.Sequential(*up)

    def forward(self, x):
        latent = self.encoder(x)  # [batch, 128, 2, 2] for a 128 x 128 input
        return self.decoder(latent)


In [None]:
# Latent dimension 64 (64 x 1 x 1)
class Autoencoder(nn.Module):
    """Convolutional autoencoder for 128 x 128 RGB input with a 64 x 1 x 1 latent map.

    The encoder applies seven stride-2 3x3 convolutions (ReLU between them, none
    after the last). The decoder mirrors the channel widths with stride-2
    transposed convolutions and ends in a sigmoid.
    """

    # Channel widths from the input image to the latent map; the decoder runs them backwards.
    _CHANNELS = (3, 4, 8, 16, 32, 64, 64, 64)

    def __init__(self):
        super().__init__()
        widths = self._CHANNELS

        down = []
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            if stage:
                down.append(nn.ReLU())
            down.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))
        self.encoder = nn.Sequential(*down)

        mirrored = widths[::-1]
        up = []
        for c_in, c_out in zip(mirrored[:-1], mirrored[1:]):
            up.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2,
                                         padding=1, output_padding=1))
            up.append(nn.ReLU())
        up[-1] = nn.Sigmoid()  # final activation is a sigmoid rather than ReLU
        self.decoder = nn.Sequential(*up)

    def forward(self, x):
        latent = self.encoder(x)  # [batch, 64, 1, 1] for a 128 x 128 input
        return self.decoder(latent)


In [None]:
# Latent dimension 2 (2 x 1 x 1)
class Autoencoder(nn.Module):
    """Convolutional autoencoder for 128 x 128 RGB input with a 2 x 1 x 1 latent map.

    The encoder applies seven stride-2 3x3 convolutions with shrinking widths
    (ReLU between them, none after the last). The decoder mirrors the channel
    widths with stride-2 transposed convolutions and ends in a sigmoid.
    """

    # Channel widths from the input image to the latent map; the decoder runs them backwards.
    _CHANNELS = (3, 128, 64, 32, 16, 8, 4, 2)

    def __init__(self):
        super().__init__()
        widths = self._CHANNELS

        down = []
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            if stage:
                down.append(nn.ReLU())
            down.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))
        self.encoder = nn.Sequential(*down)

        mirrored = widths[::-1]
        up = []
        for c_in, c_out in zip(mirrored[:-1], mirrored[1:]):
            up.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2,
                                         padding=1, output_padding=1))
            up.append(nn.ReLU())
        up[-1] = nn.Sigmoid()  # final activation is a sigmoid rather than ReLU
        self.decoder = nn.Sequential(*up)

    def forward(self, x):
        latent = self.encoder(x)  # [batch, 2, 1, 1] for a 128 x 128 input
        return self.decoder(latent)


In [None]:
def train(model, trainloader, validloader, device, criterion, optimizer, num_epochs):
    """Train an autoencoder to reconstruct its input and return the best weights.

    Each epoch runs one optimisation pass over ``trainloader`` with loss
    ``criterion(model(x), x)``, computes the mean validation loss over
    ``validloader``, and plots up to five validation images (top row) above their
    reconstructions (bottom row).

    Args:
        model: autoencoder whose output has the same shape as its input.
        trainloader: iterable of image batches (no labels), used for training.
        validloader: iterable of image batches (no labels), used for model selection.
        device: device every batch is moved to.
        criterion: reconstruction loss, e.g. ``nn.MSELoss``.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of training epochs.

    Returns:
        A deep copy of the ``state_dict`` with the lowest mean validation loss. If
        no epoch improves on ``inf`` (e.g. ``num_epochs == 0`` or NaN losses), the
        initial weights are returned instead of ``None``.
    """
    best_val_loss = float('inf')
    # Start from the current weights so the caller always gets a loadable state dict.
    best_model = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for data in trainloader:
            inputs = data.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, inputs)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {running_loss / len(trainloader):.4f}")

        model.eval()
        running_val_loss = 0.0
        with torch.no_grad():
            for data in validloader:
                inputs = data.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, inputs)
                running_val_loss += loss.item()

        val_loss = running_val_loss / len(validloader)
        print(f"Epoch {epoch+1}/{num_epochs}, Validation Loss: {val_loss:.4f}")

        # Keep a copy of the weights whenever the validation loss improves.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model = copy.deepcopy(model.state_dict())
            print(f"Best validation loss improved to {best_val_loss:.4f}. Saving model...")

        # Plot original and reconstructed validation images after each epoch.
        with torch.no_grad():
            inputs = next(iter(validloader)).to(device)
            outputs = model(inputs)
            n = min(inputs.size(0), 5)
            # Reconstructions already have the input's shape; no hard-coded 3x128x128 view.
            comparison = torch.cat([inputs[:n], outputs[:n]])
            img_grid = torchvision.utils.make_grid(comparison.cpu(), nrow=n)
            plt.figure(figsize=(20, 10))
            plt.imshow(img_grid.permute(1, 2, 0).numpy())  # CHW tensor -> HWC array
            plt.title('Original and Reconstructed Images')
            plt.show()

    print("Finished Training")
    return best_model


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

# Reconstruction training: MSE loss, Adam, 40 epochs on the unlabeled loaders.
model = Autoencoder().to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
best_model = train(
    model,
    train_unlabeled_dataloader,
    valid_unlabeled_dataloader,
    device,
    criterion,
    optimizer,
    num_epochs=40,
)

# NOTE(review): the file name encodes a 64*16*16 latent; rename it if another Autoencoder cell is active.
torch.save(best_model, '/storage/Chloe/final_model_autoencoder/autoencoder_64*16*16_split41_check.pth')


# Load autoencoder

In [None]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

autoencoder = Autoencoder().to(device)
# map_location lets a checkpoint written during a GPU run be restored on a
# CPU-only machine (and vice versa) instead of failing inside torch.load.
checkpoint = torch.load(
    '/storage/Chloe/final_model_autoencoder/autoencoder_64*16*16_split41_check.pth',
    map_location=device,
)
autoencoder.load_state_dict(checkpoint)
# From here on the autoencoder is only a fixed feature extractor: eval mode and
# frozen weights, so no autograd graph is built when images are encoded.
autoencoder.eval()
autoencoder.requires_grad_(False)
encoder = autoencoder.encoder

# Create transformed encoded dataset

In [None]:
class TransformedDataset(Dataset):
    """Apply ``transform`` to the image of each ``(image, label)`` pair of ``base_dataset``."""

    def __init__(self, base_dataset, transform):
        self.base_dataset = base_dataset
        self.transform = transform

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        sample, target = self.base_dataset[idx]
        transformed = self.transform(sample)
        return transformed, target


In [None]:
# Training-time augmentation for the classifier. Inputs arrive as [0, 1] tensors
# (BiopsyDataset's ToTensor) and are converted back to PIL for augmentation.
# No ImageNet Normalize here: the frozen encoder was trained on un-normalised
# [0, 1] images (the autoencoder's transform has no Normalize), so normalising
# would feed it inputs outside the range it has ever seen. The backbones never see
# these images directly; they receive the 1x1-conv projection of the latent code.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(20),
    transforms.Resize(128),
    transforms.ToTensor(),
])

In [None]:
# Augmentation belongs to training only. Validation/test/OOD use the same pipeline
# minus its random steps, so evaluation metrics are deterministic and reproducible.
_random_augmentations = (
    transforms.RandomHorizontalFlip,
    transforms.RandomVerticalFlip,
    transforms.RandomRotation,
)
eval_transform = transforms.Compose(
    [step for step in transform.transforms if not isinstance(step, _random_augmentations)]
)

transformed_train_dataset = TransformedDataset(balanced_train_dataset, transform)
transformed_valid_dataset = TransformedDataset(balanced_valid_dataset, eval_transform)
transformed_test_dataset = TransformedDataset(balanced_test_dataset, eval_transform)
transformed_test_ood_dataset = TransformedDataset(balanced_test_ood_dataset, eval_transform)

In [None]:
class EncodedDataset(torch.utils.data.Dataset):
    """Lazily pass each image of ``original_dataset`` through ``encoder``.

    Items are ``(encoded_image, label)``. ``encoded_image`` is the encoder output
    for that single image with the batch dimension removed, left on the encoder's
    device. Encoding happens on every access, so any random transform in
    ``original_dataset`` is re-sampled each time.
    """

    def __init__(self, original_dataset, encoder):
        self.original_dataset = original_dataset
        self.encoder = encoder

    def __len__(self):
        return len(self.original_dataset)

    def __getitem__(self, idx):
        original_image, label = self.original_dataset[idx]
        device = next(self.encoder.parameters()).device
        # Inference only: no_grad avoids building (and then discarding) an
        # autograd graph through the encoder for every sample.
        with torch.no_grad():
            encoded_image = self.encoder(original_image.to(device).unsqueeze(0)).squeeze(0)
        return encoded_image, label

In [None]:
# Wrap every split so items are (latent code, label) pairs.
encoded_train_dataset, encoded_valid_dataset, encoded_test_dataset, encoded_test_ood_dataset = (
    EncodedDataset(base, encoder)
    for base in (
        transformed_train_dataset,
        transformed_valid_dataset,
        transformed_test_dataset,
        transformed_test_ood_dataset,
    )
)
# Sanity check: shape of one encoded sample.
encoded_test_ood_dataset[0][0].shape

In [None]:
# Create dataloaders (batch size 8, shuffled) for every encoded split.
train_dataloader_cl, valid_dataloader_cl, test_dataloader_cl, test_ood_dataloader_cl = (
    DataLoader(ds, batch_size=8, shuffle=True)
    for ds in (
        encoded_train_dataset,
        encoded_valid_dataset,
        encoded_test_dataset,
        encoded_test_ood_dataset,
    )
)

# Classification 

In [None]:
def evaluate_model(model, dataloader, device=None):
    """Evaluate a single-logit binary classifier.

    Runs ``model`` in eval mode over ``dataloader``, thresholds sigmoid
    probabilities at 0.5, plots the confusion matrix, and prints the ROC AUC and
    F1 score.

    Args:
        model: module whose output can be reshaped to one logit per sample.
        dataloader: iterable of ``(inputs, labels)`` batches with 0/1 labels.
        device: where to run. Defaults to the device of the model's first
            parameter (CPU if the model has no parameters).

    Returns:
        Accuracy of the thresholded predictions.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    true_labels = []
    predictions = []
    prediction_probs = []

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            probs = torch.sigmoid(model(inputs).view(-1))  # computed once, reused below
            true_labels.extend(labels.int().tolist())
            predictions.extend((probs > 0.5).int().tolist())
            prediction_probs.extend(probs.tolist())

    acc = accuracy_score(true_labels, predictions)
    try:
        auc_roc = roc_auc_score(true_labels, prediction_probs)
    except ValueError:
        # ROC AUC is undefined when only one class is present in true_labels.
        auc_roc = float('nan')
    f1 = f1_score(true_labels, predictions)

    conf_matrix = confusion_matrix(true_labels, predictions)
    plt.figure(figsize=(7, 5))
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.show()
    print(f"AUC-ROC Score: {auc_roc}")
    print(f"F1 Score: {f1}")
    return acc


In [None]:
def train_model(model, train_dataloader, valid_dataloader, criterion, optimizer, num_epochs, filename):
    """Train a single-logit binary classifier and keep the best-validation weights.

    After every epoch the model is scored with ``evaluate_model`` on
    ``valid_dataloader``. Whenever validation accuracy improves, the weights are
    saved to ``filename`` and kept in memory.

    Args:
        model: module producing one logit per sample (reshaped with ``view(-1)``).
        train_dataloader: yields ``(inputs, labels)`` batches with 0/1 labels; the
            length of its ``dataset`` is used to average loss and accuracy.
        valid_dataloader: batches used for model selection.
        criterion: loss taking ``(logits, float_labels)``, e.g. ``nn.BCEWithLogitsLoss``.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over ``train_dataloader``.
        filename: path the best ``state_dict`` is written to.

    Returns:
        ``model`` with the best weights loaded.
    """
    # Run on whatever device the model already lives on.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    best_model_wts = copy.deepcopy(model.state_dict())
    # Start below any reachable accuracy so the first epoch is always checkpointed
    # and `filename` exists even if validation accuracy stays at 0.
    best_acc = -1.0
    best_epoch = 0
    n_train = len(train_dataloader.dataset)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        running_corrects = 0

        for inputs, labels in train_dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs).view(-1)
            loss = criterion(outputs, labels.float())
            loss.backward()
            optimizer.step()

            preds = (torch.sigmoid(outputs.detach()) > 0.5).int()
            running_loss += loss.item() * inputs.size(0)
            running_corrects += (preds == labels).sum().item()

        epoch_loss = running_loss / n_train
        epoch_acc = running_corrects / n_train
        valid_acc = evaluate_model(model, valid_dataloader)

        if valid_acc > best_acc:
            best_epoch = epoch
            best_acc = valid_acc
            best_model_wts = copy.deepcopy(model.state_dict())
            torch.save(model.state_dict(), filename)

        print(f'Epoch {epoch}, Loss: {epoch_loss:.4f}, Acc: {epoch_acc}, Valid Acc: {valid_acc}')

    print(f'Best epoch at {best_epoch}')
    model.load_state_dict(best_model_wts)
    return model

## Resnet 50

In [None]:
class MyModelResnet(nn.Module):
    """Feed an encoder's latent map into a 3-channel backbone such as ResNet-50.

    A 1x1 convolution projects ``in_channels`` latent channels to the 3 channels
    the backbone's stem expects; the result goes straight into ``base_model``.

    Args:
        base_model: backbone applied to the projected tensor.
        in_channels: number of channels in the encoder output. Defaults to 64
            (the 64 x 16 x 16 autoencoder); set it to match the active Autoencoder.
    """

    def __init__(self, base_model, in_channels=64):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, 3, kernel_size=1)
        self.base_model = base_model

    def forward(self, x):
        return self.base_model(self.conv(x))


# ImageNet-pretrained ResNet-50 with its classifier head swapped for a single-logit layer.
model_base = models.resnet50(weights=ResNet50_Weights.DEFAULT)
num_ftrs = model_base.fc.in_features
model_base.fc = nn.Linear(num_ftrs, 1)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = MyModelResnet(model_base).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer1 = Adam(model.parameters(), lr=0.001)

In [None]:
best_classifier = train_model(
    model=model,
    train_dataloader=train_dataloader_cl,
    valid_dataloader=valid_dataloader_cl,
    criterion=criterion,
    optimizer=optimizer1,
    num_epochs=20,
    filename='/storage/Chloe/final_models_classification/split41_check_64*16*16_Resnet.pth',
)

In [None]:
test_acc = evaluate_model(model=best_classifier, dataloader=test_dataloader_cl)
print(f'Test accuracy: {test_acc}')

In [None]:
ood_test_acc = evaluate_model(model=best_classifier, dataloader=test_ood_dataloader_cl)
print(f'Test accuracy on ood: {ood_test_acc}')

## Densenet 121

In [None]:
class MyModelDensenet(nn.Module):
    """Feed an encoder's latent map into a 3-channel backbone such as DenseNet-121.

    A 1x1 convolution projects ``in_channels`` latent channels to 3 channels, the
    result is resized with ``nn.Upsample`` to ``size``, and then passed to
    ``base_model``.

    Args:
        base_model: backbone applied to the projected, upsampled tensor.
        in_channels: number of channels in the encoder output (default 64).
        size: spatial size ``(H, W)`` fed to the backbone (default ``(32, 32)``).
    """

    def __init__(self, base_model, in_channels=64, size=(32, 32)):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, 3, kernel_size=1)
        self.upsample = nn.Upsample(size)
        self.base_model = base_model

    def forward(self, x):
        x = self.upsample(self.conv(x))
        return self.base_model(x)


# ImageNet-pretrained DenseNet-121 with a single-logit head. Uses the torchvision
# multi-weight API (`weights=`) instead of the deprecated `pretrained=True`,
# matching how ResNet-50 is loaded in this notebook.
model_base = models.densenet121(weights=models.DenseNet121_Weights.DEFAULT)
num_ftrs = model_base.classifier.in_features
model_base.classifier = nn.Linear(num_ftrs, 1)
model = MyModelDensenet(model_base)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer1 = Adam(model.parameters(), lr=0.001)

In [None]:
best_classifier = train_model(
    model=model,
    train_dataloader=train_dataloader_cl,
    valid_dataloader=valid_dataloader_cl,
    criterion=criterion,
    optimizer=optimizer1,
    num_epochs=20,
    filename='/storage/Chloe/final_models_classification/split41_check_64*16*16_Densenet.pth',
)

In [None]:
test_acc = evaluate_model(model=best_classifier, dataloader=test_dataloader_cl)
print(f'Test accuracy: {test_acc}')

In [None]:
ood_test_acc = evaluate_model(model=best_classifier, dataloader=test_ood_dataloader_cl)
print(f'Test accuracy on ood: {ood_test_acc}')

## Inception V3

In [None]:
class MyModelInceptionV3(nn.Module):
    """Feed an encoder's latent map into an Inception-v3 style backbone.

    A 1x1 convolution projects ``in_channels`` latent channels to 3 channels,
    which are upsampled to 299 x 299. In training mode with
    ``base_model.aux_logits`` set, the backbone returns a pair whose first element
    is kept; otherwise its output is returned as-is.

    Args:
        base_model: backbone exposing an ``aux_logits`` attribute.
        in_channels: number of channels in the encoder output (default 64).
    """

    def __init__(self, base_model, in_channels=64):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, 3, kernel_size=1)
        self.upsample = nn.Upsample((299, 299))
        self.base_model = base_model

    def forward(self, x):
        x = self.upsample(self.conv(x))
        if self.training and self.base_model.aux_logits:
            # Auxiliary logits are discarded; only the main head feeds the loss.
            logits, _aux = self.base_model(x)
            return logits
        return self.base_model(x)

# ImageNet-pretrained Inception v3, loaded with the torchvision multi-weight API
# (`weights=`) instead of the deprecated `pretrained=True`.
model_base = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT, aux_logits=True)

# Inception v3 has two final layers (one main and one auxiliary).
# Both are replaced with single-logit layers for the binary task.
num_ftrs = model_base.fc.in_features
model_base.fc = nn.Linear(num_ftrs, 1)
num_aux_ftrs = model_base.AuxLogits.fc.in_features
model_base.AuxLogits.fc = nn.Linear(num_aux_ftrs, 1)
model = MyModelInceptionV3(model_base)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer1 = Adam(model.parameters(), lr=0.001)

In [None]:
best_classifier = train_model(
    model=model,
    train_dataloader=train_dataloader_cl,
    valid_dataloader=valid_dataloader_cl,
    criterion=criterion,
    optimizer=optimizer1,
    num_epochs=20,
    filename='/storage/Chloe/final_models_classification/split41_check_64*16*16_Inception.pth',
)

In [None]:
test_acc = evaluate_model(model=best_classifier, dataloader=test_dataloader_cl)
print(f'Test accuracy: {test_acc}')

In [None]:
ood_test_acc = evaluate_model(model=best_classifier, dataloader=test_ood_dataloader_cl)
print(f'Test accuracy on ood: {ood_test_acc}')