In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
import os,gc
import torch.nn.functional as F
import seaborn as sns
import numpy as np  
import random as rnd




In [2]:
custom_data_path = './seg/brain-tumor-mri-dataset'
trainset = ImageFolder(root=os.path.join(custom_data_path, 'Training'), transform=transforms.ToTensor())


In [3]:
print(len(trainset))

5712


In [4]:
trainloader = torch.utils.data.DataLoader(trainset, batch_size=512, shuffle=False) #CHANGED: No need to load such big batch_size. With bigger datasets it may play a bad joke
data = next(iter(trainloader))[0]
mean = data.mean(axis=(0, 2, 3))
std = data.std(axis=(0, 2, 3))

In [5]:
print(mean,std)

tensor([0.5594, 0.5578, 0.5536]) tensor([0.3917, 0.3917, 0.3950])


In [6]:
normalize = transforms.Normalize(mean=mean, std=std)


In [7]:
# Define data augmentation and normalization transforms


shape = (350,350) # (350,350)
#CHANGED: We can safely transform to 512x512. Even bigger, if training goes good
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.Resize(shape),
    transforms.ToTensor(),
   normalize
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(shape),
    normalize
])

In [8]:
batch_size = 16
# Use ImageFolder to create a dataset
#labels are given as 0,1,2
#first folder has label 0 and so on
trainset = ImageFolder(root=os.path.join(custom_data_path, 'Training'), transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=8)

testset = ImageFolder(root=os.path.join(custom_data_path, 'Testing'), transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4)



In [9]:
print(len(trainset) + len(testset))

7023


In [10]:
%%capture

# Define the ResNet model

resnet50 = models.resnet50(pretrained=False)  

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(resnet50.parameters(), lr=1e-1, momentum=0.9, weight_decay=1e-4, nesterov=True) 


import torch.optim.lr_scheduler as lr_scheduler

#CHANGED
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience = 5, verbose=True)

# Training loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
resnet50.to(device)

In [11]:
num_epochs = 150
# Check if a CUDA-enabled GPU is available
if torch.cuda.is_available():
    # Set the device to CUDA
    device = torch.device("cuda")
    print("Cuda")
else:
    # If CUDA is not available, use the CPU
    device = torch.device("cpu")


Cuda


In [12]:
def save_model(model,after_lr,epoch,comments,optimizer):
    checkpoint_path = "./checkpoints/segment_res50x_" + str(epoch) + comments+"_epoch.pth"
    torch.save({
        'epoch': epoch,
        'lr':after_lr,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, checkpoint_path)
    print(f"Model saved at epoch {epoch}") 
    

In [13]:
def validate_accuracy(model, dataloader, device):
    model.eval()  # Set the model to evaluation mode
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass
            outputs = model(inputs)

            # Get predicted labels
            _, predicted = torch.max(outputs, 1)

            # Count total and correctly predicted labels
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = (correct / total) * 100.0

    return accuracy


In [14]:
import warnings
warnings.filterwarnings("ignore")

In [15]:

    
def resume(model, filename):
    
    checkpoint = torch.load(filename)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    lr = checkpoint['lr']
    #print(lr)
    return model, optimizer, lr, int(epoch)
    

    
def save_model_end(model,lr,epoch,comments,optimizer):
    checkpoint_path = "./checkpoints/segment_res50x_" + str(epoch) + comments+"_epoch_TEMP.pth"
    torch.save({
        'epoch': epoch,
        'lr':lr,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, checkpoint_path)
    print(f"Model saved at epoch {epoch}") 

In [17]:
start_epoch=0

rr = str(input("Do u want to load best model ? Y-yes, N-no "))
if (rr=="yes"):
    print("Continue...")
    resnet50, optimizer, lr, start_epoch = resume(resnet50, "./checkpoints/segment_res50x_65_IMPROVED__epoch_TEMP.pth")
    #print(optimizer.param_groups[0]["lr"] )
else:
    print("No input...")

Do u want to load best model ? Y-yes, N-no yes
Continue...


In [18]:
%%time

from tqdm.auto import tqdm

early_stop_thresh = 40
best_accuracy = -1
best_loss = 1000
best_epoch = -1

for epoch in tqdm(range(start_epoch, num_epochs)):
    resnet50.train()
    running_loss = 0.0

    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = resnet50(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    
    
    if epoch == 1:  # 50th epoch
        save_model(resnet50,after_lr,str(epoch),'_1_',optimizer)
        
    
    # Calculate validation accuracy
    accuracy = validate_accuracy(resnet50, testloader, device)

    if accuracy > best_accuracy:
        best_accuracy = accuracy
        best_epoch = epoch

        if (((running_loss / len(trainloader)) < best_loss)  ):
            best_loss=running_loss / len(trainloader)
            print("Upgrading... Best acc: ",best_accuracy," Best Loss: ",best_loss )
            save_model_end(resnet50,optimizer.param_groups[0]["lr"],str(epoch),"_IMPROVED_",optimizer)  
        
        
    elif (epoch - best_epoch >= early_stop_thresh and epoch%early_stop_thresh==0 ):
        print(f"Early stopping after {early_stop_thresh} epochs without improvement in accuracy.")
        save_model(resnet50,after_lr,str(epoch),"_best_one",optimizer)
        #checkpoint(model, f"./checkpoints/seg_epoch-{epoch}.pth")
        print("#######\nCurrent best accuracy: ",best_accuracy,"\n#######")
        break    
        
    
    before_lr = optimizer.param_groups[0]["lr"]
    scheduler.step(accuracy)
    after_lr = optimizer.param_groups[0]["lr"]
    
    if (rnd.randint(0, 1)==1):
        torch.cuda.empty_cache()
        del data, labels,outputs, inputs
        gc.collect()
    
    if ((epoch)%5==0):
        print("*****"*15)
        print(f"Epoch {epoch}, Loss: {running_loss / len(trainloader)} , Current best accuracy: {best_accuracy}")
    
        #CHANGED ADDED THIS PART
        print("Epoch %d: SGD lr %.4f -> %.4f" % (epoch, before_lr, after_lr))
        
    
    if (epoch==num_epochs):
        print("#######\nCurrent best accuracy: ",best_accuracy,"\n#######")
        
        save_model_end(resnet50,after_lr,str(epoch),"_FINISH_AT_END_",optimizer)
        

  0%|          | 0/85 [00:00<?, ?it/s]

Upgrading... Best acc:  84.05797101449275  Best Loss:  0.0612332222996937
Model saved at epoch 65
***************************************************************************
Epoch 65, Loss: 0.0612332222996937 , Current best accuracy: 84.05797101449275
Epoch 65: SGD lr 0.0000 -> 0.0000
Upgrading... Best acc:  85.1258581235698  Best Loss:  0.061193745984390625
Model saved at epoch 66
***************************************************************************
Epoch 70, Loss: 0.06590717011170477 , Current best accuracy: 85.1258581235698
Epoch 70: SGD lr 0.0000 -> 0.0000
Epoch 00008: reducing learning rate of group 0 to 1.0000e-06.
***************************************************************************
Epoch 75, Loss: 0.05856853077861191 , Current best accuracy: 85.1258581235698
Epoch 75: SGD lr 0.0000 -> 0.0000
Epoch 00014: reducing learning rate of group 0 to 1.0000e-07.
***************************************************************************
Epoch 80, Loss: 0.058809943153152466 , 