In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%reload_ext autoreload

In [19]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision import transforms, utils
from torchmetrics.functional import auc
import random
import datetime
import time
from itertools import islice
from collections import defaultdict
# pd.set_option('display.max_columns', 500)
# pd.set_option('display.max_rows', 500)
# import warnings
# warnings.filterwarnings('ignore')
# C:/Users/sshar/AppData/Roaming/jupyter/nbextensions/snippets /snippets.json (jupyter --data-dir)

In [3]:
from dataset import CheXpertDataset
from utils import *

In [4]:
@dataclass
class TrainingConfigs:
    BATCH_SIZE = 4
    EPOCHS = 3
    LEARNING_RATE = 0.0001
    CHECKPOINT_TIME_INTERVAL = 8 # seconds
    CHECKPOINT_DIR = r"model_checkpoints"
    MODEL_VERSION = "densenet121"
    TRAINED_MODEL_PATH = "2022_06_11-19_29__densenet121__epoch-0__iter-2__trainLastLoss-1.3976__validAUC-0.013.dict"

In [5]:
torch.manual_seed(Configs.SEED)
random.seed(Configs.SEED)
np.random.seed(Configs.SEED)

In [6]:
def to_gpu(x):
    return x.cuda() if torch.cuda.is_available() else x

In [7]:
def get_time_str():
    time_str = str(datetime.datetime.now())[:-10]
    trans = str.maketrans("-: ","__-")
    return time_str.translate(trans)

def create_checkpoint(model, epoch, i, valid_dataloader, criterion, results):
    valid_loss, valid_auc = calc_auc_score(model, valid_dataloader, criterion)
    results['valid_loss'].append(loss.item())
    results['valid_auc'].append(valid_auc)
    metadata = {
        "epoch": epoch,
        "iter": i,
        "trainLastLoss": np.mean(results["train_loss"][-100:]),
        "validAUC": results["valid_auc"][-1]
    }
    time_str = get_time_str()
    metadata_suffix = '__'.join([f"{k}-{round(v,4)}" for k, v in metadata.items()])
    filename = f"{time_str}__{TrainingConfigs.MODEL_VERSION}__{metadata_suffix}.dict"
    filepath = os.path.join(TrainingConfigs.CHECKPOINT_DIR, filename)
    statedata = {**metadata, **{"model": model.state_dict(), "results": results}}
    torch.save(statedata, filepath)
    print(f"{time_str}: Checkpoint Created.")

In [8]:
def avg_auc(outputs, labels):
    probas = softmax(outputs).T
    return np.mean([auc(y_proba, y_true, reorder=True) for y_proba, y_true in zip(probas, labels.T)])


def calc_auc_score(model, dataloader, criterion=None):
    all_labels = []
    all_outputs = []
    model.eval()
    for i, (images, labels) in enumerate(dataloader):
        images = to_gpu(images)
        outputs = model(images).cpu()
        all_outputs.append(outputs)
        labels = labels.cpu()
        all_labels.append(labels)
    all_outputs, all_labels = torch.cat(all_outputs), torch.cat(all_labels)
    auc_value = avg_auc(all_outputs, all_labels)
    if auc_value > 1:
        print(all_outputs, all_labels)
        input()
    loss_value = None
    if criterion:
        loss_value = criterion(all_outputs, all_labels)
    model.train()
    return loss_value, auc_value

In [9]:
train_transform = transforms.Compose([
    transforms.Resize((320,320)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
valid_transform = transforms.Compose([
    transforms.Resize((320,320)),
    transforms.ToTensor(), 
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

In [22]:
# Create data loaders.
train_dataset = CheXpertDataset(mode='train', transform=train_transform)
train_dataloader = DataLoader(train_dataset, batch_size=TrainingConfigs.BATCH_SIZE, shuffle=True)
len(train_dataset)

223414

In [11]:
valid_dataset = CheXpertDataset(mode='valid', transform=valid_transform)
valid_dataloader = DataLoader(valid_dataset, batch_size=TrainingConfigs.BATCH_SIZE, shuffle=False)
len(valid_dataset)

12

In [12]:
torch.hub._validate_not_a_forked_repo = lambda a,b,c: True # workaround for torch.hub
model = torch.hub.load('pytorch/vision:v0.10.0', 'densenet121', pretrained=False)
model.classifier = nn.Linear(in_features=1024, out_features=Configs.NUM_CLASSES, bias=True) # updating model output dim

Using cache found in C:\Users\sshar/.cache\torch\hub\pytorch_vision_v0.10.0


In [23]:
optimizer = torch.optim.Adam(model.parameters(), lr=TrainingConfigs.LEARNING_RATE)
criterion = nn.CrossEntropyLoss()
softmax = nn.Softmax(dim=1)

In [24]:
def get_previos_training_place(model_filename):
    if not model_filename:
        return 0, -1
    metadata = model_filename.split("__")
    last_epoch, last_iter = int(metadata[2][6:]), int(metadata[3][5:])
    return last_epoch, last_iter

In [25]:
last_epoch, last_iter = get_previos_training_place(TrainingConfigs.TRAINED_MODEL_PATH)
last_epoch, last_iter

(0, 2)

In [26]:
results = {
    "train_loss": [],
    "valid_loss": [],
    "valid_auc": []
}
checkpoint_time_int = 10 # seconds
start_time = time.time()
model.train()
for epoch in range(last_epoch, TrainingConfigs.EPOCHS):
    train_dataloader_iter = islice(tqdm(enumerate(train_dataloader), total=len(train_dataloader)), 
                                   last_iter+1, len(train_dataloader)) # fast foward dataloader
    for i, (images, labels) in train_dataloader_iter:
        images = to_gpu(images)
        labels = to_gpu(labels)        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        results['train_loss'].append(loss.item())
        if time.time()-start_time > TrainingConfigs.CHECKPOINT_TIME_INTERVAL:
            create_checkpoint(model, epoch, i, valid_dataloader, criterion, results)
            print('Epoch [%d/%d],   Iter [%d/%d],   Train Loss: %.4f,   Valid Loss: %.4f,   Valid AUC: %.4f' 
                   %(epoch+1, TrainingConfigs.EPOCHS,
                     i+1, len(train_dataloader), 
                     np.mean(results["train_loss"][-100:]),
                     results["valid_loss"][-1],
                     results["valid_auc"][-1]),
                 end="\n\n")
            start_time = time.time()

  0%|          | 0/55854 [00:00<?, ?it/s]

2022_06_11-20_08: Checkpoint Created.
Epoch [1/3],   Iter [6/55854],   Train Loss: 2.4056,   Valid Loss: 2.3492,   Valid AUC: 0.0020

2022_06_11-20_08: Checkpoint Created.
Epoch [1/3],   Iter [9/55854],   Train Loss: 2.3796,   Valid Loss: 1.1595,   Valid AUC: 0.0016

2022_06_11-20_08: Checkpoint Created.
Epoch [1/3],   Iter [12/55854],   Train Loss: 2.0778,   Valid Loss: 0.8529,   Valid AUC: 0.0019



KeyboardInterrupt: 