# Imports

In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
# %reload_ext autoreload
# import gc
# gc.collect()
# torch.cuda.empty_cache()
# torch.cuda.memory_summary(device=None, abbreviated=False)
# torch.cuda.empty_cache()

In [3]:
import os
import time
from dataclasses import dataclass
from itertools import islice
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.models import densenet161, DenseNet161_Weights, vit_b_16, ViT_B_16_Weights
from tqdm.notebook import tqdm

# pd.set_option('display.max_columns', 500)
# pd.set_option('display.max_rows', 500)
# import warnings
# warnings.filterwarnings('ignore')
# C:/Users/sshar/AppData/Roaming/jupyter/nbextensions/snippets /snippets.json (jupyter --data-dir)

In [2]:
from dataset import CheXpertDataset
import utils
from utils import vprint
from utils import to_gpu

# Configs 

In [20]:
@dataclass
class TrainingConfigs:
    """Hyper-parameters and paths for a training run.

    Every attribute carries a type annotation so that ``@dataclass`` actually
    registers it as a field (un-annotated attributes are ignored by
    ``dataclasses``). The defaults remain readable and writable on the class
    itself (``TrainingConfigs.BATCH_SIZE``), which is how the rest of the
    notebook uses them; an instance with overrides can also be built, e.g.
    ``TrainingConfigs(BATCH_SIZE=8)``.
    """
    # Dataset root, relative to the notebook's working directory.
    DATA_DIR: str = os.path.join("..", "data", "CheXpert", "CheXpert-v1.0-small")
    # Directory for checkpoints (consumed by the utils checkpoint helpers).
    CHECKPOINT_DIR: str = r"checkpoints"
    BATCH_SIZE: int = 4
    EPOCHS: int = 10
    LEARNING_RATE: float = 0.0001
    # Minimum wall-clock time between two checkpoints, in seconds.
    CHECKPOINT_TIME_INTERVAL: int = 5
    MODEL_VERSION: str = "vit_b_16"
    # NOTE(review): presumably an explicit checkpoint to resume from - confirm in utils.
    TRAINED_MODEL_PATH: Optional[str] = None
    # Number of batches per loader; filled in once the DataLoaders are built.
    TRAIN_LOADER_SIZE: Optional[int] = None
    VALID_LOADER_SIZE: Optional[int] = None
    # Validation rows kept after truncation (for debugging purposes).
    VALID_SIZE: Optional[int] = 48

In [21]:
# Seed the random number generators before any data loading / model creation.
# NOTE(review): presumably seeds python/numpy/torch with a default seed - confirm in utils.set_seed.
utils.set_seed()

# Training

## Training Setup

In [22]:
# Training-time preprocessing with light augmentation.
#
# Order matters: ToTensor() yields float tensors in [0, 1], which is the range
# torchvision's colour and sharpness ops expect for float tensors. Normalize
# therefore has to run LAST; running it before ColorJitter /
# RandomAdjustSharpness hands those ops standardized (partly negative) values
# and corrupts the image whenever one of them fires.
train_transform = transforms.Compose([
    # transforms.Resize((320,320)),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # augmentation (operates on un-normalized [0, 1] tensors)
    transforms.RandomHorizontalFlip(p=0.25),
    transforms.RandomApply([transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.01)], p=0.1),
    transforms.RandomApply([transforms.GaussianBlur(kernel_size=(3, 3), sigma=(0.25, 0.75))], p=0.1),
    transforms.RandomAdjustSharpness(sharpness_factor=0.75, p=0.1),
    transforms.RandomAdjustSharpness(sharpness_factor=1.25, p=0.1),
    # standardization, same statistics as valid_transform
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Evaluation-time preprocessing: resize, convert to tensor, standardize.
# No random augmentation is applied here.
valid_transform = transforms.Compose([
    # transforms.Resize(size=(320, 320)),
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [23]:
# Training split: dataset wrapped in a shuffled mini-batch loader.
train_dataset = CheXpertDataset(
    labels_filename='train.csv',
    data_dir=TrainingConfigs.DATA_DIR,
    transform=train_transform,
)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=TrainingConfigs.BATCH_SIZE,
    shuffle=True,
)
# Record the number of batches per epoch on the shared config.
TrainingConfigs.TRAIN_LOADER_SIZE = len(train_dataloader)
len(train_dataset)

223414

In [24]:
# Validation split, read in file order (no shuffling).
valid_dataset = CheXpertDataset(
    labels_filename='valid.csv',
    data_dir=TrainingConfigs.DATA_DIR,
    transform=valid_transform,
)
# Debug hack for speed: keep only the first VALID_SIZE label rows.
valid_dataset.labels = valid_dataset.labels[:TrainingConfigs.VALID_SIZE]
valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=TrainingConfigs.BATCH_SIZE,
    shuffle=False,
)
# Record the number of validation batches on the shared config.
TrainingConfigs.VALID_LOADER_SIZE = len(valid_dataloader)
len(valid_dataset)

48

In [25]:
# model = densenet161(weights=DenseNet161_Weights.DEFAULT)
# num_features = model.classifier.in_features
# model.classifier = nn.Sequential(
#     nn.Linear(num_features, num_features, bias=True),
#     nn.ReLU(),
#     nn.Dropout(p=0.1),
#     nn.Linear(in_features=num_features, out_features=utils.Configs.NUM_CLASSES, bias=True)
# )

In [26]:
# ViT-B/16 initialised from the IMAGENET1K_V1 weights; its final head layer is
# replaced by a two-layer MLP that emits one logit per class
# (utils.Configs.NUM_CLASSES outputs).
model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
num_features = model.heads.head.in_features
model.heads.head = nn.Sequential(
    nn.Linear(in_features=num_features, out_features=num_features, bias=True),
    nn.ReLU(),
    nn.Dropout(p=0.1),
    nn.Linear(in_features=num_features, out_features=utils.Configs.NUM_CLASSES, bias=True),
)

In [27]:
# Adam over all model parameters, with a small L2 penalty.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=TrainingConfigs.LEARNING_RATE,
    weight_decay=1e-5,
)
# Cut the learning rate by 10x once the monitored value (to be minimised)
# has not improved for more than `patience` scheduler steps.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.1,
    patience=5,
)
# Sigmoid + binary cross-entropy in one op, averaged over all elements.
# A simple way to handle the multi-label problem: per-class probabilities
# do not have to sum to 1.
# Final hard labels from the network: torch.round(torch.sigmoid(pred))
criterion = nn.BCEWithLogitsLoss(reduction='mean')

## Training Loop 

In [28]:
# Resume from the most recent checkpoint and train.
# `last_epoch` / `last_iter` mark where the previous run stopped; `results`
# accumulates the loss history ('train_loss', 'valid_loss').
model, results, last_epoch, last_iter = utils.get_previous_training_place(model, TrainingConfigs)
model.train()
model = to_gpu(model)
start_time = time.time()  # wall-clock time of the last checkpoint
for epoch in range(last_epoch, TrainingConfigs.EPOCHS):
    # Fast-forward the dataloader past the batches already processed in the
    # resumed epoch.
    # NOTE: islice still pulls the skipped batches through the loader, and with
    # shuffle=True they only match the interrupted run's batches if the RNG
    # state is the same.
    train_dataloader_iter = islice(tqdm(enumerate(train_dataloader), total=len(train_dataloader)),
                                   last_iter + 1, len(train_dataloader))
    for i, (images, labels) in train_dataloader_iter:
        images = to_gpu(images)
        labels = to_gpu(labels)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        results['train_loss'].append(loss.item())
        # Checkpoint on a wall-clock schedule rather than every N iterations.
        if time.time() - start_time > TrainingConfigs.CHECKPOINT_TIME_INTERVAL:
            utils.create_checkpoint(model, epoch, i, valid_dataloader, criterion, results, TrainingConfigs)
            # create_checkpoint evaluates on valid_dataloader (implementation not
            # visible here); re-assert train mode in case it switched to eval().
            model.train()
            start_time = time.time()
    # Only the resumed epoch is partial: every later epoch must start at batch 0.
    # Without this reset each following epoch would skip its first batches too.
    last_iter = -1
    # Feed the scheduler the mean of the recently recorded validation losses.
    # BCE losses are never negative, so negative entries (the -1.0 placeholder
    # seen in the checkpoint log) are not real measurements and are dropped.
    # With no usable value the scheduler step is skipped instead of being fed
    # NaN from the mean of an empty list.
    recent_valid_losses = [v for v in results["valid_loss"][-len(train_dataloader):] if v >= 0]
    if recent_valid_losses:
        scheduler.step(np.mean(recent_valid_losses))

2022-07-09 19:41: Loaded model - epoch:0, iter:6


  0%|          | 0/55854 [00:00<?, ?it/s]

2022-07-09 19:41: Only one class present in y_true. ROC AUC score is not defined in that case.
2022-07-09 19:41: 2022_07_09-19_41: Checkpoint Created.
2022-07-09 19:41: Epoch [1/10],   Iter [7/55853],   Train Loss: 0.3646,   Valid Loss: -1.0000,   Valid AUC: -1.0000

2022-07-09 19:42: Only one class present in y_true. ROC AUC score is not defined in that case.
2022-07-09 19:42: 2022_07_09-19_42: Checkpoint Created.
2022-07-09 19:42: Epoch [1/10],   Iter [8/55853],   Train Loss: 0.3681,   Valid Loss: -1.0000,   Valid AUC: -1.0000

2022-07-09 19:42: Only one class present in y_true. ROC AUC score is not defined in that case.
2022-07-09 19:42: 2022_07_09-19_42: Checkpoint Created.
2022-07-09 19:42: Epoch [1/10],   Iter [9/55853],   Train Loss: 0.3767,   Valid Loss: -1.0000,   Valid AUC: -1.0000



KeyboardInterrupt: 