In [1]:
import os
import time
import torch
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


from PIL import Image
from tqdm import tqdm
from torch import nn
from torch import optim
from torch.utils.data import DataLoader 

from torchvision.transforms import Compose, ToTensor, Lambda
from torchvision.transforms import Resize, Normalize
from torchvision.models import resnet18
from torchvision.datasets import ImageFolder
from tqdm import tqdm
import boto3

In [2]:
# Seed every random number generator this notebook relies on
SEED = 1234

for seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seed_rng(SEED)

# Restrict cuDNN to deterministic algorithms
torch.backends.cudnn.deterministic = True

In [3]:
# Train / validation image folders, resolved relative to the current working directory
_classification_root = os.path.join(os.curdir, '..', '..', 'data', 'preprocessed', 'classification')
train_data_dir = os.path.join(_classification_root, 'train')
val_data_dir = os.path.join(_classification_root, 'val')

In [4]:
# Define the backbone: torchvision's ResNet-18, requested with its pretrained weights
model = resnet18(pretrained=True)

In [5]:
# Classification layer definition: swap the backbone's final fully connected
# layer for a fresh linear layer with OUTPUT_DIM outputs.
INPUT_DIM, OUTPUT_DIM = model.fc.in_features, 4

FC_layer = nn.Linear(in_features=INPUT_DIM, out_features=OUTPUT_DIM)
FC_layer.requires_grad_(True)  # weight and bias of the new head are trainable
model.fc = FC_layer

In [6]:
# Keep every weight trainable: nothing is frozen, so the whole network is fine-tuned
for weight in model.parameters():
    weight.requires_grad_(True)

In [7]:
def count_parameters(model):
    """Return the total number of elements across the model's trainable parameters.

    Only parameters whose ``requires_grad`` flag is set are counted.
    """
    trainable_total = 0
    for weight in model.parameters():
        if weight.requires_grad:
            trainable_total += weight.numel()
    return trainable_total

# Report the trainable-parameter count, formatted with thousands separators
print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 11,178,564 trainable parameters


In [10]:
# Hyperparameters and training setup
lr = 0.001
batch_size = 256
epochs = 10
weight_decay=0

# Choose the device and move the model onto it BEFORE building the optimizer,
# so the optimizer is constructed over the parameters that are actually trained.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# The loss module is created once, directly on the training device.
criterion = nn.CrossEntropyLoss().to(device)

optimizer = optim.Adam(model.parameters(),lr=lr,weight_decay=weight_decay)
# Halve the learning rate every 3 scheduler steps.
# (The name `schedular` is kept as-is because the training loop refers to it.)
schedular = optim.lr_scheduler.StepLR(optimizer, gamma=0.5,step_size=3,verbose=True)
# Gradient scaler for mixed precision; it only takes effect when passed to train() as `scaler`.
scaler = torch.cuda.amp.GradScaler()

Adjusting learning rate of group 0 to 1.0000e-02.


In [11]:
# Image transformation definition
# Per-channel statistics passed to Normalize below.
IMAGE_NET_MEANS, IMAGE_NET_STDEVS = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

# Steps run in this order: resize -> force 3-channel RGB -> tensor -> normalise.
# NOTE(review): Resize with an int size rescales only the shorter edge, so
# batching presumably relies on the preprocessed images sharing one aspect
# ratio - confirm against the preprocessing step.
_preprocessing_steps = [
    Resize(224),
    Lambda(lambda img: img.convert('RGB')),
    ToTensor(),
    Normalize(mean=IMAGE_NET_MEANS, std=IMAGE_NET_STDEVS),
]
transforms = Compose(_preprocessing_steps)

In [12]:
# Data loading and labelling: one ImageFolder dataset per split, both using
# the same preprocessing pipeline.
train_data, val_data = (
    ImageFolder(root=split_dir, transform=transforms)
    for split_dir in (train_data_dir, val_data_dir)
)

In [13]:
# Show the class-name -> label-index mapping of each split so the two can be compared
print('Train data classes: ', train_data.class_to_idx,'\n')
print('Val data classes: ', val_data.class_to_idx)

Train data classes:  {'CNV': 0, 'DME': 1, 'DRUSEN': 2, 'NORMAL': 3} 

Val data classes:  {'CNV': 0, 'DME': 1, 'DRUSEN': 2, 'NORMAL': 3}


In [14]:
# Data iterator definition

# Training batches are reshuffled every epoch.
train_iterator = DataLoader(train_data,
                            shuffle = True,
                            batch_size=batch_size)

# Validation is read in a fixed order: shuffling adds nothing to evaluation,
# changes which samples fall into the (possibly shorter) final batch from one
# epoch to the next, and draws from the global random state.
val_iterator = DataLoader(val_data,
                          shuffle = False,
                          batch_size=batch_size)

In [15]:
def calculate_accuracy(y_pred, y):
    """Return the fraction of samples whose highest-scoring class equals the label.

    ``y_pred`` is indexed as (sample, class score) and ``y`` holds one label
    per sample. The result is a scalar float tensor.
    """
    predicted = y_pred.argmax(dim=1)
    hits = (predicted == y.view(predicted.shape)).sum()
    return hits.float() / y.shape[0]

In [16]:
def epoch_time(start_time, end_time):
    """Split the interval between two timestamps (in seconds) into minutes and seconds.

    Returns a ``(minutes, seconds)`` tuple of ints; the fractional part of the
    remaining seconds is truncated.
    """
    duration = end_time - start_time
    whole_minutes = int(duration / 60)
    leftover_seconds = int(duration - 60 * whole_minutes)
    return whole_minutes, leftover_seconds

In [17]:
def train(model, iterator, optimizer, criterion, device,schedular ,scaler= False):
    """Train ``model`` for one epoch and return the epoch's mean loss and accuracy.

    Args:
        model: network to optimise; it is switched to training mode here.
        iterator: iterable yielding ``(image, label)`` batches.
        optimizer: optimizer whose gradients are reset and stepped once per batch.
        criterion: loss callable invoked as ``criterion(predictions, labels)``.
            The epoch loss assumes it returns the mean over the batch.
        device: device every batch is moved to before the forward pass.
        schedular: learning-rate scheduler, stepped once after the last batch.
        scaler: falsy for full-precision training; otherwise a gradient scaler,
            in which case the forward pass runs under CUDA autocast and the
            backward pass / optimizer step go through the scaler.

    Returns:
        ``(epoch_loss, epoch_acc)`` as Python floats, averaged per sample so a
        shorter final batch is not over-weighted.

    Raises:
        ZeroDivisionError: if ``iterator`` yields no samples.
    """
    print('training')

    loss_sum = 0.0
    acc_sum = 0.0
    samples_seen = 0

    model.train()

    for (image, label) in tqdm(iterator):

        image = image.to(device)
        label = label.to(device)
        batch_len = label.shape[0]

        optimizer.zero_grad()

        if scaler:

            with torch.cuda.amp.autocast():

                label_pred = model(image)
                loss = criterion(label_pred, label)
                # Sanity check that autocast really produced half-precision outputs
                assert label_pred.dtype is torch.float16

        else:
            label_pred = model(image)
            loss = criterion(label_pred, label)

        acc = calculate_accuracy(label_pred, label)

        # Weight the per-batch means by batch size so the totals are per-sample sums
        loss_sum += loss.item() * batch_len
        acc_sum += acc.item() * batch_len
        samples_seen += batch_len

        if scaler:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        else:
            loss.backward()
            optimizer.step()

    # One scheduler step per epoch
    schedular.step()

    return loss_sum / samples_seen, acc_sum / samples_seen

In [18]:
def evaluate(model, iterator, criterion, device):
    """Evaluate ``model`` over ``iterator`` and return its mean loss and accuracy.

    Args:
        model: network to evaluate; it is switched to evaluation mode here.
        iterator: iterable yielding ``(image, label)`` batches.
        criterion: loss callable invoked as ``criterion(predictions, labels)``.
            The reported loss assumes it returns the mean over the batch.
        device: device every batch is moved to before the forward pass.

    Returns:
        ``(epoch_loss, epoch_acc)`` as Python floats, averaged per sample so
        the result does not depend on how samples are split into batches.

    Raises:
        ZeroDivisionError: if ``iterator`` yields no samples.
    """
    print('validating')
    loss_sum = 0.0
    acc_sum = 0.0
    samples_seen = 0

    model.eval()

    # No gradients are needed for evaluation
    with torch.no_grad():

        for (image, label) in tqdm(iterator):

            image = image.to(device)
            label = label.to(device)
            batch_len = label.shape[0]

            label_pred = model(image)

            loss = criterion(label_pred, label)

            acc = calculate_accuracy(label_pred, label)

            # Weight the per-batch means by batch size so the totals are per-sample sums
            loss_sum += loss.item() * batch_len
            acc_sum += acc.item() * batch_len
            samples_seen += batch_len

    return loss_sum / samples_seen, acc_sum / samples_seen

In [21]:
# Training driver: run `epochs` rounds of train + validate, checkpoint the best
# model by validation loss, and record per-epoch metrics.
criterion = criterion.to(device)
# Start at +inf so any finite validation loss counts as an improvement
best_valid_loss = float('inf')
# Path that torch.save writes the best state dict to
model_name = 'pretrained_resnet18_weights'
# One row of metrics is appended per epoch
log = pd.DataFrame(columns=['train_loss','train_acc' ,'val_loss', 'val_acc'])

for epoch in range(epochs):
    
    start_time = time.monotonic()
    
    # scaler=False: this run trains in full precision (the mixed-precision path in train() is not taken)
    train_loss, train_acc = train(model, train_iterator, optimizer, criterion,device,schedular,scaler=False)
    val_loss, val_acc = evaluate(model, val_iterator, criterion, device)
        
    # Save the weights only when the validation loss improves on the best seen so far
    if val_loss < best_valid_loss:
        best_valid_loss = val_loss
        torch.save(model.state_dict(), model_name)

    end_time = time.monotonic()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    # Append this epoch's metrics and rewrite the CSV every epoch, so the log
    # on disk is current even if the run is interrupted
    log.loc[len(log.index)] = [train_loss,train_acc,val_loss,val_acc]
    log.to_csv('log.csv')
    
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s, current time: {time.ctime()}')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {val_loss:.3f} |  Val. Acc: {val_acc*100:.2f}%')

  0%|          | 0/2268 [00:00<?, ?it/s]

training


  0%|          | 5/2268 [00:50<6:20:46, 10.10s/it] 


KeyboardInterrupt: 