# **Script for data cleaning**

In [None]:
import os
from PIL import Image


def delete_corrupted_images(directory, extensions=None):
    """Walk ``directory`` recursively and delete every file PIL cannot read.

    Each candidate file is opened with PIL and ``verify()`` is called on it.
    Opening alone only parses the header, so damaged image data would slip
    through; ``verify()`` asks PIL to check the file contents as well.

    Parameters:
        directory (str): Root directory to scan (sub-directories included).
        extensions (iterable of str, optional): If given, only files whose
            lower-cased extension (e.g. ".jpg") is in this collection are
            checked; all other files are left untouched. ``None`` (default)
            checks every file. Note that any non-image file (README, CSV, ...)
            is then treated as corrupted and deleted.

    Returns:
        list of str: Paths of the files that were deleted.
    """
    allowed = None if extensions is None else {ext.lower() for ext in extensions}
    deleted = []
    for root, _dirs, files in os.walk(directory):
        for file in files:
            file_path = os.path.join(root, file)
            if allowed is not None and os.path.splitext(file)[1].lower() not in allowed:
                continue

            try:
                with Image.open(file_path) as img:
                    img.verify()
            except (OSError, SyntaxError):
                # OSError also covers IOError (an alias) and PIL's
                # UnidentifiedImageError; SyntaxError was already handled by
                # the original cleaning code.
                print(f"Corrupted file: {file_path}. Deleting...")
                os.remove(file_path)
                deleted.append(file_path)
    return deleted


directoryName = "/content/drive/MyDrive/CCMT_FInal Dataset"

# The function is defined above this call: the original order (call first,
# def second) raised NameError on a fresh kernel.
# NOTE(review): Google Drive is mounted in a later cell; mount it first or this
# cell only reports "Directory not found.".
if os.path.exists(directoryName):
    delete_corrupted_images(directoryName)
    print("Operation complete.")
else:
    print("Directory not found.")

# **TPU Support**

In [None]:
!pip install torch-xla
!pip install torchvision

In [None]:
import torch_xla
import torch_xla.core.xla_model as xm

In [None]:
# For TPU runs only: obtain the XLA device from torch_xla's xla_model helper.
# NOTE(review): the "Parameters Definition" cell below unconditionally reassigns
# DEVICE = torch.device("cuda" / "cpu"), so on a top-to-bottom run this TPU
# device is discarded unless that assignment is skipped or changed.
DEVICE = xm.xla_device()

# **Necessary Libraries**

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms, models
from PIL import Image, ImageFile
import matplotlib.pyplot as plt
import datetime
from tqdm import tqdm



In [None]:
# Mount Google Drive at /content/drive; every dataset, checkpoint and plot path
# in this notebook lives under /content/drive/MyDrive.
# NOTE(review): the data-cleaning cell at the top reads /content/drive before
# this cell runs on a top-to-bottom execution.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Tell PIL to load truncated image files instead of raising an error
# (presumably so a damaged file does not abort a DataLoader epoch — confirm).
ImageFile.LOAD_TRUNCATED_IMAGES = True

# **Parameters Definition**

In [None]:
# Accelerator selection. NOTE: this overrides any TPU DEVICE set in the
# "TPU Support" cell above.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Number of DataLoader worker processes: up to 4 (bounded by the available
# cores) on CPU, otherwise 2.
# Compare DEVICE.type, not DEVICE: a torch.device only compares equal to other
# torch.device objects, so `DEVICE == "cpu"` was always False and 2 was always
# chosen. os.cpu_count() can return None when the count is unknown.
NUM_WORKERS = min(4, os.cpu_count() or 1) if DEVICE.type == "cpu" else 2
BATCH_SIZE = 64
IMAGE_SIZE = 224
PATH = "/content/drive/MyDrive/AdversialDataset"
TRAIN_RATIO = 0.90

# **Loading Dataset**

In [None]:
# Each image is resized to an IMAGE_SIZE x IMAGE_SIZE square and converted to a
# tensor; no normalization step is applied.
transform = transforms.Compose(
    [
        transforms.Resize(size=(IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
    ]
)

# Classes are taken from the sub-folders of PATH (torchvision ImageFolder layout).
full_dataset = datasets.ImageFolder(PATH, transform=transform)

In [None]:
# Fixed seed for the train/validation split. Without it random_split draws a
# new partition on every run, so after resuming from a checkpoint part of the
# validation set would already have been trained on (inflating its metrics).
# NOTE: checkpoints created before this seed was fixed used a different split.
SPLIT_SEED = 42

train_size = int(TRAIN_RATIO * len(full_dataset))
test_size = len(full_dataset) - train_size

# Split the dataset into training and validation sets, reproducibly.
train_dataset, test_dataset = random_split(
    full_dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Create data loaders. Only training needs shuffling; a fixed validation order
# keeps evaluation runs repeatable.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)


In [None]:
# Number of classes ImageFolder discovered; sizes the classifier head and the
# one-hot label encoding used in training and evaluation.
num_classes = len(full_dataset.classes)

# **Model: ResNet-50**



In [None]:
# Pretrained ResNet-50 whose final fully connected layer is replaced by a new
# linear head with one output per dataset class.
resnet = models.resnet50(pretrained=True)
resnet.fc = nn.Linear(in_features=resnet.fc.in_features, out_features=num_classes)

# Cross-entropy on the head's raw outputs ("mean" is the default reduction).
criterion = nn.CrossEntropyLoss(reduction="mean")

In [None]:
# SGD hyper-parameters, effective when training starts without a checkpoint.
# NOTE(review): optimizer.load_state_dict() in the checkpoint cell restores the
# saved hyper-parameters, overriding these. A momentum of 0.009 is unusually
# small (0.9 is the common choice); confirm it is intentional.
LEARNING_RATE = 0.0000001
SGD_MOMENTUM = 0.009
optimizer = optim.SGD(resnet.parameters(), lr=LEARNING_RATE, momentum=SGD_MOMENTUM)

# **Loading from Checkpoint**

In [None]:
# Resume from a saved checkpoint. The model is moved to DEVICE first, the file
# is loaded with its tensors mapped onto DEVICE, then the model weights and the
# optimizer state are restored from it.
RESUME_CHECKPOINT = '/content/drive/MyDrive/cropsClassifierCheckpoints/checkpoint_11_epoch_lr=0.000001_2.pth'
resnet = resnet.to(DEVICE)
checkpoint = torch.load(RESUME_CHECKPOINT, map_location=torch.device(DEVICE))
resnet.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [None]:
# Restore the training history stored in the checkpoint under the same keys the
# training loop writes.
# Bug fix: critValidationLoss was read from "clsfValidationLoss" (copy-paste
# slip), putting misclassification percentages into the criterion-loss curve.
critValidationLoss = checkpoint["critValidationLoss"]
critTrainingLoss = checkpoint["critTrainingLoss"]
clsfValidationLoss = checkpoint["clsfValidationLoss"]
clsfTrainingLoss = checkpoint["clsfTrainingLoss"]
epochIters = checkpoint["epochIters"]
lossCriterionList = checkpoint["lossCriterionList"]
lossMisclassificationList = checkpoint["lossMisclassificationList"]
bIters = checkpoint["bIters"]
# NOTE(review): the "Training" cells below re-initialize all of these lists,
# discarding what was restored here; skip them when resuming.

# **Training**

In [None]:
# Per-epoch history: criterion-loss and misclassification statistics for the
# training and validation passes.
# NOTE(review): running this cell after "Loading from Checkpoint" discards the
# history restored there.
critTrainingLoss, critValidationLoss = [], []
clsfTrainingLoss, clsfValidationLoss = [], []

# Misclassification percentage of each validation batch.
clsfValidationBatch = []

# Epoch numbers, used as the x-axis of the per-epoch plots.
epochIters = []

print(f'device {DEVICE} , num_workers {NUM_WORKERS}' )

device cuda , num_workers 2


In [None]:
# Per-batch history spanning all epochs. Each list starts with a single 0 so the
# training loop can extend the iteration counter with bIters[-1] + 1.
lossCriterionList, lossMisclassificationList, bIters = [0], [0], [0]

In [None]:
def calculateMisclassificationPercentage(logits, labels):
    """
    Calculate the percentage of misclassified samples in a batch.

    Parameters:
        logits (torch.Tensor): Predicted scores of shape (N, C); the predicted
            class of each row is the index of its maximum.
        labels (torch.Tensor): Ground truth, either one-hot / class-probability
            rows of shape (N, C) (as built by the training loop) or integer
            class indices of shape (N,).

    Returns:
        float: Percentage (0-100) of samples whose predicted class differs from
        the true class.

    Raises:
        ValueError: If the batch is empty (the percentage is undefined).
    """
    if labels.size(0) == 0:
        raise ValueError("cannot compute a misclassification percentage for an empty batch")

    _, predicted_indices = torch.max(logits, 1)
    # Plain class indices are used as-is; one-hot rows are reduced to their argmax.
    if labels.dim() == 1:
        label_indices = labels
    else:
        _, label_indices = torch.max(labels, 1)

    correct = (predicted_indices == label_indices).sum().item()
    total = label_indices.size(0)

    return ((1 - (correct / total)) * 100)


In [None]:
# Directory for per-epoch checkpoints and loss plots; created if it is missing.
checkpoint_dir = '/content/drive/MyDrive/cropsClassifierCheckpoints'
if not os.path.isdir(checkpoint_dir):
    os.makedirs(checkpoint_dir)

# **Loop for Training**

In [None]:
num_epochs = 10
resnet.to(DEVICE)

# One-hot lookup table built once (row i encodes class i) instead of on every batch.
one_hot_table = torch.eye(num_classes)


def _save_and_show_plot(x, series, title, xlabel, ylabel, filename):
    """Plot each ``(values, color, label)`` in ``series`` against ``x`` on a new
    figure, save it to ``checkpoint_dir/filename``, show it, then close it.

    A fresh figure per plot keeps curves of different plots from accumulating
    on the same axes when the pyplot backend does not reset figures on show().
    """
    fig, ax = plt.subplots()
    for values, color, label in series:
        ax.plot(x, values, color=color, label=label)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.legend()
    fig.savefig(os.path.join(checkpoint_dir, filename))
    plt.show()
    plt.close(fig)


for epoch in range(num_epochs):
    timeStamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

    # ---- training pass ----
    # Train mode affects layers such as batch norm; autograd is on either way.
    resnet.train()
    epochLoss = 0.0
    epochMisclassification = []  # misclassification % of each batch of this epoch

    progress = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}")
    for inputs, labels in progress:
        inputs = inputs.to(DEVICE)
        # one-hot encode the class indices, e.g. 2 -> [0, 0, 1, 0, ...]
        labels = one_hot_table[labels].to(DEVICE)

        optimizer.zero_grad()
        outputs = resnet(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batchLoss = loss.item()
        batchMisclassification = calculateMisclassificationPercentage(outputs, labels)
        epochLoss += batchLoss
        epochMisclassification.append(batchMisclassification)

        # running per-batch history (spans all epochs)
        lossCriterionList.append(batchLoss)
        lossMisclassificationList.append(batchMisclassification)
        bIters.append(bIters[-1] + 1)

        progress.set_postfix(criterion=f"{batchLoss:.4f}", misclass=f"{batchMisclassification:.2f}%")

    numTrainBatches = len(epochMisclassification)
    if numTrainBatches == 0:
        raise RuntimeError("train_loader yielded no batches")

    _save_and_show_plot(
        bIters,
        [(lossCriterionList, "green", "Criterion Loss"),
         (lossMisclassificationList, "blue", "Classification Loss")],
        "Epoch Loss", "Iterations", "Loss",
        f"{timeStamp}_Epoch_Loss_epoch_{epoch}",
    )

    # Average criterion loss over this epoch's batches.
    critTrainingLoss.append(epochLoss / numTrainBatches)
    # Average misclassification % over this epoch's batches. (Previously the sum
    # of every batch percentage recorded so far, across all epochs, was stored.)
    clsfTrainingLoss.append(sum(epochMisclassification) / numTrainBatches)
    epochIters.append(epoch + 1)

    # ---- validation pass (no autograd graph) ----
    resnet.eval()
    clsfValidationBatch = []
    evalLoss = 0.0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs = inputs.to(DEVICE)
            labels = one_hot_table[labels].to(DEVICE)
            outputs = resnet(inputs)
            evalLoss += criterion(outputs, labels).item()
            clsfValidationBatch.append(calculateMisclassificationPercentage(outputs, labels))

    numValBatches = len(clsfValidationBatch)
    if numValBatches == 0:
        raise RuntimeError("val_loader yielded no batches")
    critValidationLoss.append(evalLoss / numValBatches)
    clsfValidationLoss.append(sum(clsfValidationBatch) / numValBatches)

    # How training vs. validation criterion loss evolves per epoch.
    _save_and_show_plot(
        epochIters,
        [(critTrainingLoss, "blue", "Training Loss"),
         (critValidationLoss, "red", "Validation Loss")],
        "Criterion Loss", "Epochs", "Loss",
        f"{timeStamp}_CriterionLoss_epoch_{epoch}",
    )
    # How training vs. validation misclassification evolves per epoch.
    _save_and_show_plot(
        epochIters,
        [(clsfTrainingLoss, "blue", "Training Loss"),
         (clsfValidationLoss, "red", "Validation Loss")],
        "Misclassification Loss", "Epochs", "Misclassification (%)",
        f"{timeStamp}_MisclassificationLoss_epoch_{epoch}",
    )

    print(f'Epoch {epoch+1}/{num_epochs},Training Loss {critTrainingLoss[-1]:.4f}, Validation Loss {critValidationLoss[-1]:.4f}\nTraining Misclass {clsfTrainingLoss[-1]:.4f}, Validation Misclass {clsfValidationLoss[-1]:.4f}')

    # Checkpoint for resuming after interruptions. It is built after training and
    # validation so it holds this epoch's final model/optimizer state and metrics.
    checkpoint = {
        'epoch': epoch + 1,
        'model_state_dict': resnet.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'critTrainingLoss': critTrainingLoss,
        'clsfTrainingLoss': clsfTrainingLoss,
        'critValidationLoss': critValidationLoss,
        'clsfValidationLoss': clsfValidationLoss,
        'epochIters': epochIters,
        'bIters': bIters,
        'lossCriterionList': lossCriterionList,
        'lossMisclassificationList': lossMisclassificationList,
    }
    checkpoint_filename = f'chkpnt_adv_{timeStamp}_{clsfValidationLoss[-1]}.pth'
    checkpoint_path = os.path.join(checkpoint_dir, checkpoint_filename)
    torch.save(checkpoint, checkpoint_path)





# **Model Evaluation**

In [None]:


def evaluate_model(model, val_loader, num_classes, criterion, adversialFunc=None, device=None):
    """Evaluate ``model`` on ``val_loader``; return mean loss and mean misclassification.

    Parameters:
        model (torch.nn.Module): Model to evaluate; it is switched to eval mode.
        val_loader (iterable): Yields ``(inputs, labels)`` batches whose labels
            are integer class indices.
        num_classes (int): Number of classes, used to one-hot encode the labels.
        criterion (callable): Loss called as ``criterion(outputs, one_hot_labels)``.
        adversialFunc (callable, optional): If given, called as
            ``adversialFunc(model, inputs, one_hot_labels)``; its return value is
            scored as the model outputs. It runs with autograd enabled in case it
            needs gradients (e.g. to perturb the inputs).
        device (torch.device, optional): Where batches are moved. Defaults to the
            device of the model's first parameter.

    Returns:
        tuple[float, float]: (mean criterion loss per batch, mean misclassification
        percentage per batch).

    Raises:
        ValueError: If ``val_loader`` yields no batches.
    """
    model.eval()
    if device is None:
        device = next(model.parameters()).device
    # One-hot lookup table, built once on the target device.
    one_hot_table = torch.eye(num_classes, device=device)

    clsfValidationBatch = []
    evalLoss = 0.0

    for inputs, labels in tqdm(val_loader, desc="Validation"):
        inputs = inputs.to(device)
        labels = one_hot_table[labels.to(device)]

        if adversialFunc is not None:
            outputs = adversialFunc(model, inputs, labels)
        else:
            # Clean evaluation needs no autograd graph.
            with torch.no_grad():
                outputs = model(inputs)

        # Criterion loss and misclassification percentage of this batch.
        loss = criterion(outputs, labels)
        evalLoss += loss.item()
        clsfValidationBatch.append(calculateMisclassificationPercentage(outputs, labels))

    numBatches = len(clsfValidationBatch)
    if numBatches == 0:
        raise ValueError("val_loader yielded no batches")

    return evalLoss / numBatches, sum(clsfValidationBatch) / numBatches


In [None]:
# Evaluate on the validation split. num_classes comes from the dataset rather
# than a hard-coded 22, so the one-hot width always matches the model head.
evalLoss, misCls = evaluate_model(resnet, val_loader, num_classes, criterion)
print(f"Validation loss {evalLoss:.4f}, misclassification {misCls:.2f}%")