In [None]:
# Enable IPython's autoreload extension in mode 2, so edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# %reload_ext autoreload

In [None]:
import time
from dataclasses import dataclass
from itertools import islice
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm.notebook import tqdm

# pd.set_option('display.max_columns', 500)
# pd.set_option('display.max_rows', 500)
# import warnings
# warnings.filterwarnings('ignore')
# C:/Users/sshar/AppData/Roaming/jupyter/nbextensions/snippets /snippets.json (jupyter --data-dir)

In [None]:
from dataset import CheXpertDataset
import utils
from utils import to_gpu

In [None]:
@dataclass
class TrainingConfigs:
    """Hyper-parameters and checkpoint settings for this training run.

    Every attribute carries a type annotation so that ``@dataclass`` registers
    it as a real field; un-annotated class attributes are ignored by
    ``dataclass``, which would leave the generated ``__init__``/``__repr__``/
    ``__eq__`` empty. The defaults are still readable straight off the class
    (``TrainingConfigs.BATCH_SIZE``) and can now also be overridden per
    instance (``TrainingConfigs(EPOCHS=1)``).
    """

    BATCH_SIZE: int = 4  # mini-batch size used by the train and valid DataLoaders
    EPOCHS: int = 3  # upper bound of the epoch loop
    LEARNING_RATE: float = 0.0001  # passed to the Adam optimizer as lr
    CHECKPOINT_TIME_INTERVAL: int = 8  # seconds
    # The three settings below are only handed to the `utils` helpers in this
    # notebook; their exact use is defined there.
    CHECKPOINT_DIR: str = r"model_checkpoints"
    MODEL_VERSION: str = "densenet121"
    TRAINED_MODEL_PATH: Optional[str] = None

In [None]:
# Call the project's seeding helper with its default arguments.
# NOTE(review): presumably seeds the Python/NumPy/torch RNGs for reproducibility -
# confirm against utils.set_seed, which is not visible from this notebook.
utils.set_seed()

In [None]:
# Shared preprocessing settings, defined once instead of being repeated per pipeline.
IMAGE_SIZE = (320, 320)
# Per-channel statistics; these are the ImageNet values listed in the torchvision docs.
# NOTE(review): three channels are assumed here - confirm CheXpertDataset yields RGB images.
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]


def build_transform():
    """Return a new resize -> to-tensor -> normalize preprocessing pipeline.

    A fresh ``Compose`` is built on every call, so the train and validation
    pipelines stay independent objects, as they were before.
    """
    return transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
    ])


train_transform = build_transform()
valid_transform = build_transform()

In [None]:
# Training split: build the dataset, then wrap it in a shuffling DataLoader.
train_dataset = CheXpertDataset(mode='train', transform=train_transform)
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=TrainingConfigs.BATCH_SIZE,
    shuffle=True,
)
# Displayed as the cell output: number of training samples.
len(train_dataset)

223414

In [None]:
# Number of validation samples kept while debugging, so checkpoints evaluate quickly.
# Set to None to validate on the full split.
VALID_DEBUG_SAMPLES = 12

valid_dataset = CheXpertDataset(mode='valid', transform=valid_transform)
if VALID_DEBUG_SAMPLES is not None:
    # Debug-only speed hack: truncating `labels` shrinks the dataset
    # (the cell output below reports the resulting length).
    valid_dataset.labels = valid_dataset.labels[:VALID_DEBUG_SAMPLES]
# Validation order is kept fixed (no shuffling).
valid_dataloader = DataLoader(valid_dataset, batch_size=TrainingConfigs.BATCH_SIZE, shuffle=False)
len(valid_dataset)

12

In [None]:
# Workaround for torch.hub: replace the private fork-validation hook with a no-op.
# Accepting any arguments keeps the patch working if the hook's signature
# differs between torch versions.
torch.hub._validate_not_a_forked_repo = lambda *args, **kwargs: True

# DenseNet-121 architecture without pretrained weights (pretrained=False).
model = torch.hub.load('pytorch/vision:v0.10.0', 'densenet121', pretrained=False)

# Swap the classification head so the output dimension matches the project's
# class count. The input width is read from the existing head instead of being
# hard-coded, so it stays correct if the backbone is changed.
model.classifier = nn.Linear(
    in_features=model.classifier.in_features,
    out_features=utils.Configs.NUM_CLASSES,
    bias=True,
)

Using cache found in C:\Users\sshar/.cache\torch\hub\pytorch_vision_v0.10.0


In [None]:
# Adam over all model parameters at the configured learning rate.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=TrainingConfigs.LEARNING_RATE,
)
# NOTE(review): CrossEntropyLoss treats the task as single-label classification -
# confirm this matches the label format produced by CheXpertDataset.
criterion = nn.CrossEntropyLoss()

In [None]:
# Resume point: the helper returns the (possibly restored) model, the results
# history and the epoch / iteration to continue from.
model, results, last_epoch, last_iter = utils.get_previos_training_place(model, TrainingConfigs)

# `model` has just been rebound to the helper's return value, so build the
# optimizer against *that* object's parameters. The torch.optim docs advise
# constructing optimizers after the model is in its final place; no step has
# been taken yet, so a fresh Adam is equivalent to the earlier one.
# NOTE(review): optimizer state does not appear to be restored anywhere in this
# notebook - confirm whether utils checkpoints it.
optimizer = torch.optim.Adam(model.parameters(), lr=TrainingConfigs.LEARNING_RATE)

last_epoch, last_iter

(0, 17)

In [None]:
# Training loop with time-based checkpointing and resume support.
# A checkpoint (which also appends validation metrics to `results`) is created
# whenever more than CHECKPOINT_TIME_INTERVAL seconds have elapsed since the
# previous one.
start_time = time.time()
model.train()
num_batches = len(train_dataloader)
for epoch in range(last_epoch, TrainingConfigs.EPOCHS):
    # Fast-forward only inside the epoch being resumed; every later epoch must
    # start from batch 0, otherwise its first `last_iter + 1` batches would be
    # silently skipped.
    # NOTE(review): with shuffle=True the skipped batches match the ones already
    # trained on only if the shuffling order is reproduced - verify the seeding.
    first_iter = last_iter + 1 if epoch == last_epoch else 0
    train_dataloader_iter = islice(
        tqdm(enumerate(train_dataloader), total=num_batches),
        first_iter, num_batches)  # fast forward dataloader
    for i, (images, labels) in train_dataloader_iter:
        images = to_gpu(images)
        labels = to_gpu(labels)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        results['train_loss'].append(loss.item())
        if time.time()-start_time > TrainingConfigs.CHECKPOINT_TIME_INTERVAL:
            utils.create_checkpoint(model, epoch, i, valid_dataloader, criterion, results, TrainingConfigs)
            # The checkpoint helper receives the validation loader and presumably
            # switches the model to eval mode; re-assert training mode so
            # dropout / batch-norm keep behaving correctly (no-op if unchanged).
            model.train()
            # Train loss is averaged over the last 100 recorded iterations.
            print('Epoch [%d/%d],   Iter [%d/%d],   Train Loss: %.4f,   Valid Loss: %.4f,   Valid AUC: %.4f' 
                   %(epoch+1, TrainingConfigs.EPOCHS,
                     i, num_batches-1, 
                     np.mean(results["train_loss"][-100:]),
                     results["valid_loss"][-1],
                     results["valid_auc"][-1]),
                 end="\n\n")
            # Restart the timer after the checkpoint so its cost is not counted.
            start_time = time.time()

  0%|          | 0/55854 [00:00<?, ?it/s]

2022_06_14-14_43: Checkpoint Created.
Epoch [1/3],   Iter [19/55853],   Train Loss: 1.5158,   Valid Loss: 1.5201,   Valid AUC: 0.0068



KeyboardInterrupt: 