In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
import os,gc
import torch.nn.functional as F
import seaborn as sns
import numpy as np  # Import NumPy as np
import random as rnd




In [2]:
custom_data_path = './seg/brain-tumor-mri-dataset'
trainset = ImageFolder(root=os.path.join(custom_data_path, 'Training'), transform=transforms.ToTensor())


In [3]:
print(len(trainset))

5712


In [4]:
trainloader = torch.utils.data.DataLoader(trainset, batch_size=512, shuffle=False) #CHANGED: No need to load such big batch_size. With bigger datasets it may play a bad joke
data = next(iter(trainloader))[0]
mean = data.mean(axis=(0, 2, 3))
std = data.std(axis=(0, 2, 3))

In [5]:
print(mean,std)

tensor([0.5594, 0.5578, 0.5536]) tensor([0.3917, 0.3917, 0.3950])


In [6]:
normalize = transforms.Normalize(mean=mean, std=std)


In [7]:
# Define data augmentation and normalization transforms


shape = (350,350) # (350,350)
#CHANGED: We can safely transform to 512x512. Even bigger, if training goes good
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.Resize(shape),
    transforms.ToTensor(),
   normalize
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(shape),
    normalize
])

In [8]:
batch_size = 16
# Use ImageFolder to create a dataset
#labels are given as 0,1,2
#first folder has label 0 and so on
trainset = ImageFolder(root=os.path.join(custom_data_path, 'Training'), transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=8)

testset = ImageFolder(root=os.path.join(custom_data_path, 'Testing'), transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4)



In [9]:
print(len(trainset) + len(testset))

7023


In [10]:
%%capture

# Define the ResNet model

# resnet50 = models.resnet50(pretrained=False, num_classes=4)  # You can choose a different ResNet variant
resnet50 = models.resnet50(pretrained=False)  

#models.wide_resnet50_2
# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(resnet50.parameters(), lr=1e-1, momentum=0.9, weight_decay=5e-3) #CHANGE: We can start with lower LR, but because of scheduler we set high one


import torch.optim.lr_scheduler as lr_scheduler

#CHANGED
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience = 5, verbose=True)

#scheduler = lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=1e-7, total_iters=1_00)
# ALT SCHEDULER
# scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

# Training loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
resnet50.to(device)

In [11]:
num_epochs = 150
# Check if a CUDA-enabled GPU is available
if torch.cuda.is_available():
    # Set the device to CUDA
    device = torch.device("cuda")
    print("Cuda")
else:
    # If CUDA is not available, use the CPU
    device = torch.device("cpu")


Cuda


In [12]:
def save_model(model,lr,epoch,optimizer):
    checkpoint_path = "./checkpoints/segment_res50x_" + str(epoch) +"_epoch.pth"
    torch.save({
        'epoch': epoch,
        'lr':after_lr,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, checkpoint_path)
    print(f"Model saved at epoch {epoch}") 
    

In [13]:
def validate_accuracy(model, dataloader, device):
    model.eval()  # Set the model to evaluation mode
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass
            outputs = model(inputs)

            # Get predicted labels
            _, predicted = torch.max(outputs, 1)

            # Count total and correctly predicted labels
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = (correct / total) * 100.0

    return accuracy


In [14]:
import warnings
warnings.filterwarnings("ignore")

In [15]:
# def checkpoint(model, filename):
#     torch.save(model.state_dict(), filename)
    
def resume(model, filename):
    
    checkpoint = torch.load(filename)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    lr = checkpoint['lr']
    #print(lr)
    return model, optimizer, lr, epoch
    
    #model.load_state_dict(torch.load(filename))

    
def save_model_end(model,after_lr,epoch,optimizer):
    checkpoint_path = "./checkpoints/segment_res50x_" + str(epoch) +"_epoch_TEMP.pth"
    torch.save({
        'epoch': epoch,
        'lr':after_lr,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, checkpoint_path)
    print(f"Model saved at epoch {epoch}") 

In [16]:
# !mkdir checkpoints

# !rm -rf checkpoints

!ls -lt checkpoints

total 2300408
-rw-r----- 1 scur0834 scur0834 204821503 Oct 17 16:30  segment_res50x_105_epoch_TEMP.pth
-rw-r----- 1 scur0834 scur0834 204821503 Oct 17 03:23  segment_res50x_30_epoch_TEMP.pth
-rw-r----- 1 scur0834 scur0834 204821503 Oct 17 02:48  segment_res50x_15_epoch_TEMP.pth
-rw-r----- 1 scur0834 scur0834 204821757 Oct 17 02:15  segment_res50x_1_epoch.pth
-rw-r----- 1 scur0834 scur0834 204821503 Oct 17 02:13  segment_res50x_0_epoch_TEMP.pth
-rw-r----- 1 scur0834 scur0834 204821503 Oct 16 22:39  segment_res50x_90_epoch_TEMP.pth
-rw-r----- 1 scur0834 scur0834 204821503 Oct 16 21:54  segment_res50x_75_epoch_TEMP.pth
-rw-r----- 1 scur0834 scur0834 204821503 Oct 16 21:19  segment_res50x_60_epoch_TEMP.pth
-rw-r----- 1 scur0834 scur0834 204821503 Oct 16 20:44  segment_res50x_45_epoch_TEMP.pth
-rw-r----- 1 scur0834 scur0834 204821503 Oct 16 17:43  segment_res50_pretrained_0_epoch_TEMP.pth
-rw-r----- 1 scur0834 scur0834 204821503 Oct 16 05:11  segment_res50x_5_epoch_TEMP.pth
-rw-

In [17]:
start_epoch=0

rr = str(input("Do u want to load best model ? Y-yes, N-no "))
if (rr=="yes"):
    print("Continue...")
    resnet50, optimizer, lr, start_epoch = resume(resnet50, "./checkpoints/segment_res50x_105_epoch_TEMP.pth")
    #print(optimizer.param_groups[0]["lr"] )
else:
    print("see ya")

Do u want to load best model ? Y-yes, N-no yes
Continue...


In [18]:
%%time

from tqdm.auto import tqdm

early_stop_thresh = 40
best_accuracy = -1
best_epoch = -1

for epoch in tqdm(range(start_epoch, num_epochs)):
    resnet50.train()
    running_loss = 0.0

    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = resnet50(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    
    
    if epoch == 1:  # 50th epoch
        save_model(resnet50,after_lr,str(epoch+start_epoch),optimizer)
        
    
    # Calculate validation accuracy
    accuracy = validate_accuracy(resnet50, testloader, device)

    if accuracy > best_accuracy:
        best_accuracy = accuracy
        best_epoch = epoch
        
        
    elif (epoch - best_epoch >= early_stop_thresh and epoch%early_stop_thresh==0 ):
        print(f"Early stopping after {early_stop_thresh} epochs without improvement in accuracy.")
        save_model(resnet50,after_lr,str(epoch+start_epoch)+" best one",optimizer)
        #checkpoint(model, f"./checkpoints/seg_epoch-{epoch}.pth")
        print("#######\nCurrent best accuracy: ",best_accuracy,"\n#######")
        break    
        
    
    before_lr = optimizer.param_groups[0]["lr"]
    scheduler.step(accuracy)
    after_lr = optimizer.param_groups[0]["lr"]
    
    if (rnd.randint(0, 1)==1):
        torch.cuda.empty_cache()
        del data, labels,outputs, inputs
        gc.collect()
    
    if ((epoch)%5==0):
        print("*****"*15)
        print(f"Epoch {epoch}, Loss: {running_loss / len(trainloader)} , Current best accuracy: {best_accuracy}")
    
        #CHANGED ADDED THIS PART
        print("Epoch %d: SGD lr %.4f -> %.4f" % (epoch, before_lr, after_lr))
        
    
    if (epoch==num_epochs):
        print("#######\nCurrent best accuracy: ",best_accuracy,"\n#######")
        
        save_model_end(resnet50,after_lr,str(epoch+start_epoch)+"_FINISH_AT_END_",optimizer)
        
    elif ((epoch)%15==0):

        save_model_end(resnet50,after_lr,epoch,optimizer)

  0%|          | 0/45 [00:00<?, ?it/s]

***************************************************************************
Epoch 105, Loss: 0.458386659163053 , Current best accuracy: 77.19298245614034
Epoch 105: SGD lr 0.0000 -> 0.0000
Model saved at epoch 105
***************************************************************************
Epoch 110, Loss: 0.4568937724294449 , Current best accuracy: 77.42181540808542
Epoch 110: SGD lr 0.0000 -> 0.0000
***************************************************************************
Epoch 115, Loss: 0.45910767058865365 , Current best accuracy: 77.42181540808542
Epoch 115: SGD lr 0.0000 -> 0.0000
***************************************************************************
Epoch 120, Loss: 0.45610514245614286 , Current best accuracy: 77.42181540808542
Epoch 120: SGD lr 0.0000 -> 0.0000
Model saved at epoch 120
***************************************************************************
Epoch 125, Loss: 0.46667494601896164 , Current best accuracy: 77.42181540808542
Epoch 125: SGD lr 0.0000 -> 0.000

In [46]:
# torch.save(resnet50.state_dict(), f'./checkpoints/shape: {str(shape)} and batch size: {str(batch_size)} segment_resnet50_full.pth')

In [None]:
#resnet_load = models.resnet50(pretrained=False, num_classes=4)  # You can choose a different ResNet variant


In [None]:
#checkpoint_path = '/kaggle/working/res50_normal_5.pth'
#checkpoint = torch.load(checkpoint_path)
# Retrieve the components from the checkpoint
#epoch = checkpoint['epoch']
#model_state_dict = checkpoint['model_state_dict']
#optimizer_state_dict = checkpoint['optimizer_state_dict']
#loss = checkpoint['loss']


In [None]:
#resnet_load.load_state_dict(model_state_dict)
#resnet_load.eval()

In [None]:
# Set the model to evaluation mode

resnet50.eval()

correct = 0
total = 0

# Disable gradient calculation during inference
with torch.no_grad():
    for images, labels in testloader:
        images = images.to(device)  # Move data to the appropriate device (e.g., CPU or GPU)
        labels = labels.to(device)

        # Forward pass
        outputs = resnet50(images)
        
        # Get predicted labels
        _, predicted = torch.max(outputs, 1)
        
        # Count total and correctly predicted labels
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Calculate the accuracy
accuracy = (correct / total) * 100.0

print(f'Accuracy of the model on the test dataset: {accuracy:.2f}%')

In [None]:
chest_path = '/kaggle/input/labeled-chest-xray-images'
testset_chest = ImageFolder(root=os.path.join(chest_path, 'chest_xray/test'), transform=transform_test)
testloader_chest = torch.utils.data.DataLoader(testset_chest, batch_size=batch_size, shuffle=False, num_workers=4)

In [None]:
print(len(testset_chest))

In [None]:
def get_softmax_scores(model, dataloader):
    scores = []
    outputs_raw = []
    device = next(model.parameters()).device  # Get the device of the model's parameters

    with torch.no_grad():
        for inputs, _ in dataloader:
            inputs = inputs.to(device)  # Move inputs to the same device as the model
            outputs = model(inputs)
            softmax_scores = F.softmax(outputs, dim=1)
            scores.append(softmax_scores)
            outputs_raw.append(outputs)

    scores = torch.cat(scores)
    outputs_raw = torch.cat(outputs_raw)
    return scores, outputs_raw

In [None]:
brain_softmax_scores,scores_brain = get_softmax_scores(resnet50, testloader)  # MNIST softmax scores (ID)
print(len(brain_softmax_scores))

chest_softmax_scores,scores_chest = get_softmax_scores(resnet50, testloader_chest)  
print(len(chest_softmax_scores))

In [None]:
threshold = 0.9  # Adjust as needed

def classify_samples(softmax_scores, threshold):
    max_scores, _ = torch.max(softmax_scores, dim=1)
    ood_samples = max_scores < threshold
    return ood_samples,max_scores

In [None]:
brain_ood_samples,max_brain = classify_samples(brain_softmax_scores, threshold)
chest_ood_samples,max_chest = classify_samples(chest_softmax_scores, threshold)

In [None]:
#print(brain_ood_samples,chest_ood_samples)
print(len(brain_softmax_scores))
print(len(chest_softmax_scores))
print(torch.sum(brain_ood_samples).item())
print(torch.sum(chest_ood_samples).item())

In [None]:
# Transfer tensors from GPU to CPU
max_brain_cpu = max_brain.cpu().numpy()
max_chest_cpu = max_chest.cpu().numpy()
# Create a KDE plot
sns.kdeplot(max_brain_cpu, shade=True, color = "blue")
sns.kdeplot(max_chest_cpu, shade=True, color = "yellow")


# Customize the plot
plt.xlabel("Max softmax Score")
plt.ylabel("Density")
plt.title("Softmax Scores Distribution (KDE Plot)")

# Display the plot or save it to a file
plt.show()

In [None]:
d=2
def energy(out, axis = 1, numpy = True, T = 1, correction = False , ty='not_num'): #actually takes negative energy
    print("Called")

    if ty == 'not_num':
        #print("correct")
        #out = out.detach().numpy()

        #print(out)
        #denominator of softmax np.sum(np.exp(out/T)
        scores = T*np.log(np.sum(np.exp(out/T), axis = axis))
    if correction:
        print("here")
        scores -= T*np.log(d)

    return scores

In [None]:
scores_brain_cpu = scores_brain.cpu().numpy()
scores_chest_cpu = scores_chest.cpu().numpy()
# Energy Scores
ID_energy_score = energy(scores_brain_cpu)
OOD_energy_score = energy(scores_chest_cpu)

In [None]:
# Create a KDE plot
sns.kdeplot(ID_energy_score, shade=True, color = "blue")
sns.kdeplot(OOD_energy_score, shade=True, color = "yellow")


# Customize the plot
plt.xlabel("Energy Score")
plt.ylabel("Density")
plt.title("Energy Scores Distribution (KDE Plot)")

# Display the plot or save it to a file
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt

def thresholdize(distID, distOOD, alpha=0.5, plot=True):
    distID = distID[np.isfinite(distID)]
    distOOD = distOOD[np.isfinite(distOOD)]

    densID, binsID, _ = plt.hist(distID, bins=100, density=True, alpha=0.5, label='ID')
    densOOD, binsOOD, _ = plt.hist(distOOD, bins=100, density=True, alpha=0.5, label='OOD')

    if not plot:
        plt.clf()

    widthID = binsID[1] - binsID[0]
    widthOOD = binsOOD[1] - binsOOD[0]
    n = len(densID)

    low = np.min([binsID[0], binsOOD[0]])
    high = np.max([binsID[-1], binsOOD[-1]])
    thresholds = np.linspace(low, high, n)

    massID = np.zeros(n)
    massOOD = np.zeros(n)

    for i, x in enumerate(thresholds):
        massOOD[i] = np.sum(densOOD[binsOOD[:-1] < x]) * widthOOD
        massID[i] = np.sum(densID[binsID[:-1] >= x]) * widthID

    total = 2 * (alpha * massOOD + (1 - alpha) * massID)
    thresIdx = np.argmax(total)
    threshold = thresholds[thresIdx]
    bestTotal = np.max(total)  # not very intuitive when using alpha-thresholding

    if plot:
        plt.vlines(threshold, 0, 1.1 * np.max(densOOD), label='Threshold', linestyles='dashed')
        plt.legend()

        plt.figure()
        plt.plot(thresholds, massID, label='ID mass right')
        plt.plot(thresholds, massOOD, label='OOD mass left')
        plt.plot(thresholds, total, label='Total')
        plt.vlines(threshold, 0, 1.1 * np.max(total), label='Threshold', linestyles='dashed')
        plt.legend()

        falseOOD = 1.0 - massID[thresIdx]
        falseID = 1.0 - massOOD[thresIdx]

        print(f"Fraction of OOD data falsely classified as ID is {falseID:.3g}")
        print(f"Fraction of ID data falsely classified as OOD is {falseOOD:.3g}")

    return threshold


In [None]:
thr = thresholdize(ID_energy_score,OOD_energy_score)