In [None]:
import torch
import torch.nn as nn 
import torch.nn.functional as F
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision.datasets import ImageFolder
from torchvision import transforms, utils
import torchvision
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix
import seaborn as sb
import pandas as pd
from models import VGGClassifier
from utils import train

### Dataloader

In [None]:
# ImageNet per-channel statistics; presumably what the VGG backbone in models.VGGClassifier
# was pretrained with — TODO confirm.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
SEED = 42  # fixes the train/val split and sampling order across kernel restarts

# Augmented preprocessing — used for training images only.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomAffine(translate=(0.05, 0.05), degrees=0),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD, inplace=True),
])

# Deterministic preprocessing — used for validation and test images.
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD, inplace=True),
])

subset_ratio = 0.2  # fraction of ./train held out for validation
orig_dataset = ImageFolder(root='./dataset/chest_xray/train/', transform = train_transform)
# Second view of the same ./train files WITHOUT augmentation, so validation metrics are
# computed on unperturbed images (previously validation reused train_transform).
val_dataset = ImageFolder(root='./dataset/chest_xray/train/', transform = test_transform)
# The split indices below are shared by both views, so they must enumerate the same files in the same order.
assert val_dataset.samples == orig_dataset.samples, "train/val views of ./train must index the same files"

n = len(orig_dataset)  # total number of examples
print(f"Total Dataset size: {n}")

# Seeded random permutation of all indices, so the split is reproducible.
indices = np.random.default_rng(SEED).permutation(n).tolist()

# calculate the split index for the subset: first `split` indices -> validation, rest -> training
split = int(np.floor(subset_ratio * n))

# Seeded generators also make the per-epoch sampling order reproducible.
val_sampler = SubsetRandomSampler(indices[:split], generator=torch.Generator().manual_seed(SEED))
train_sampler = SubsetRandomSampler(indices[split:], generator=torch.Generator().manual_seed(SEED))
test_dataset = ImageFolder(root='./dataset/chest_xray/test/', transform = test_transform)

train_dataloader = DataLoader(orig_dataset, batch_size = 16, sampler=train_sampler)
val_dataloader = DataLoader(val_dataset, batch_size = 8, sampler=val_sampler)
# NOTE(review): shuffling is not needed for evaluation; kept in case later cells rely on random test order.
test_dataloader = DataLoader(test_dataset, shuffle = True, batch_size = 1)

### Model Training

In [None]:
# Build the classifier, loss, and optimizer on the available device.
model = VGGClassifier()
loss_fn = nn.CrossEntropyLoss()
device = "cuda" if torch.cuda.is_available() else "cpu"

# Move the model BEFORE constructing the optimizer, as the torch.optim docs recommend,
# so the optimizer is created over the parameters in their final location.
model.to(device)
print(device)

lr = 0.0001
optim = torch.optim.SGD(model.parameters(), lr = lr, momentum=0.9, weight_decay=1e-3)

In [None]:
# Training hyper-parameters for utils.train.
EPOCHS = 1
# NOTE(review): with a single epoch this early-stopping threshold presumably can never trigger — confirm against utils.train.
EARLY_STOP_THRESHOLD = 4

# `train` comes from the local utils module. Its result is unpacked into three histories:
# training loss, validation F1, and validation accuracy. TODO confirm their granularity
# (per step vs. per epoch) against utils.train.
train_loss, val_f1, val_acc = train(
    model,
    optim,
    loss_fn,
    train_dataloader,
    val_dataloader,
    epochs=EPOCHS,
    early_stop_threshold=EARLY_STOP_THRESHOLD,
)

In [None]:
# Plot the training-loss and validation-F1 histories returned by `train`, stacked vertically,
# and save the figure to PDF.
# NOTE(review): the filename says "conv_net" although the model is VGGClassifier; kept unchanged.
fig, (loss_ax, f1_ax) = plt.subplots(2, 1)

loss_ax.plot(train_loss)
loss_ax.set_xticks(range(len(train_loss)))  # one tick per recorded value
loss_ax.set(title="Training loss", xlabel="Training Steps", ylabel="Training Loss")

f1_ax.plot(val_f1)
f1_ax.set_xticks(range(len(val_f1)))
f1_ax.set(title="Validation F1", xlabel="Training Steps", ylabel="Validation F1 Score")

fig.tight_layout()
fig.savefig("plot_conv_net.pdf")
plt.show()