In [4]:
import torch
from dataset import ImageDataset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from sklearn.utils import shuffle
import numpy as np
import torch.optim as optim
from model import  CNN_model
import torch.nn as nn
import numpy as np
from utils_cells import calculate_precision_recall_per_class, get_accuracies_per_class
import sys
from vit import VisionTransformer
import time
import pandas as pd
from sklearn.metrics import confusion_matrix

In [5]:


torch.cuda.empty_cache()

model = VisionTransformer(image_size=32, in_channels=4, num_classes=4, hidden_dims=[16, 16])


batch_size = 64

trainset = ImageDataset(data_path='train_data')
trainloader = DataLoader(trainset, batch_size=batch_size,
                         shuffle=True, num_workers=5)


testset =ImageDataset(data_path='validation_data')

testloader = DataLoader(testset, batch_size=batch_size,
                        shuffle=False, num_workers=5)

model = model.to('cuda:0')
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

recalls = []
val_recalls = []

losses = []
val_losses = []



def multiclass_accuracy_per_class(outputs, labels, num_classes):
    """
    Calculate per-class accuracy for a multiclass classification problem.

    Args:
    - predictions (torch.Tensor): Model predictions (logits).
    - labels (torch.Tensor): Ground truth labels.
    - num_classes (int): Number of classes.

    Returns:
    - per_class_accuracy (list): List of per-class accuracies.
    """
    outputs = outputs.data.cpu().numpy().argmax(axis=1)
    labels = labels.data.cpu().numpy().argmax(axis=1)

    confusion = confusion_matrix(labels, outputs)
    per_class_accuracy = confusion.diagonal() / confusion.sum(axis=1)
    return per_class_accuracy

In [15]:
for epoch in range(20):
    recall = []
    precision = []

    recall_val = []
    precision_val = []

    training_loss = []
    start_time = time.time()
    elapsed_time = 0
    model.train() 
    total_per_class_accuracy = [0.0] * 4
    for i, data in enumerate(trainloader):

        inputs, labels = data
        labels = torch.Tensor(labels)
        
        inputs = inputs.to('cuda')
        labels = labels.to('cuda')

        optimizer.zero_grad()
        outputs = model(inputs)
        
        loss = criterion(outputs, labels)
        
        per_class_accuracy = multiclass_accuracy_per_class(outputs, labels, 4)
        total_per_class_accuracy = [acc + per_acc for acc, per_acc in zip(total_per_class_accuracy, per_class_accuracy)]
        average_per_class_accuracy = [acc / len(trainloader) if not None else 0.0 for acc in total_per_class_accuracy]

        print(total_per_class_accuracy)

        loss.backward()
        optimizer.step()
        
        training_loss.append(loss.item())


        outputs = outputs.data.cpu().numpy().argmax(axis=1)
        labels = labels.data.cpu().numpy().argmax(axis=1)

        recall_per_class, precision_per_class = calculate_precision_recall_per_class(labels, outputs)
        if (i + 1) % 1000 == 0 or i == len(trainloader) - 1:
            elapsed_time = time.time() - start_time
            batches_done = i + 1
            batches_total = len(trainloader)
            batches_remaining = batches_total - batches_done
            time_per_batch = elapsed_time / batches_done
            estimated_time_remaining = time_per_batch * batches_remaining

            # Convert times to minutes
            elapsed_time_minutes = elapsed_time / 60
            estimated_time_remaining_minutes = estimated_time_remaining / 60



            # Print training progress and estimated time remaining on the same line
            progress_message = f'Batch {i}/{len(trainloader)},Remaining: {estimated_time_remaining_minutes:.2f}min , loss {loss.item()}, class 1: {recall_per_class[0]}, class 2: {recall_per_class[1]}, class 3: {recall_per_class[2]}, class 4: {recall_per_class[3]}'
            sys.stdout.write("\r" + progress_message)
            sys.stdout.flush()


        recall.append(recall_per_class)
        precision.append(precision_per_class)

    average_per_class_accuracy = [acc / len(trainloader) for acc in total_per_class_accuracy]
    for class_idx, acc in enumerate(average_per_class_accuracy):
        print(f'Training Class {class_idx + 1} Accuracy: {acc * 100:.2f}%')
            
 
    model.eval()  
    val_loss = []
    total_per_class_accuracy = [0.0] * 4
    with torch.no_grad():
        for data in testloader:
            inputs, labels = data
            labels = torch.Tensor(labels)
            inputs = inputs.to('cuda:0')
            labels = labels.to('cuda:0')


            outputs = model(inputs)
            val_loss_crt = criterion(outputs, labels)

            per_class_accuracy = multiclass_accuracy_per_class(outputs, labels, 4)
            total_per_class_accuracy = [acc + per_acc for acc, per_acc in zip(total_per_class_accuracy, per_class_accuracy)]

            val_loss.append(val_loss_crt.item())

            outputs = outputs.data.cpu().numpy().argmax(axis=1)
            labels = labels.data.cpu().numpy().argmax(axis=1)
            
            recall_per_class, precision_per_class = calculate_precision_recall_per_class(labels, outputs)
            
            recall_val.append(recall_per_class)
            precision_val.append(precision_per_class)



    recall = np.mean(np.array(recall), axis=0)
    recall_val = np.mean(np.array(recall_val), axis=0)
    
    recalls.append(recall)
    val_recalls.append(recall_val)
    
    losses.append(training_loss)
    val_losses.append(val_loss)
    
    print(f'Epoch {epoch + 1}, Training loss: {np.mean(training_loss)} Validation Loss: {np.mean(val_loss)}')
    print(f'Epoch {epoch + 1}, Training Class 1: {recall[0]}, Class 2: {recall[1]}, Class 3: {recall[2]}, Class 4: {recall[3]}')
    print(f'Epoch {epoch + 1}, Validation Class 1: {recall_val[0]}, Class 2: {recall_val[1]}, Class 3: {recall_val[2]}, Class 4: {recall_val[3]}')
    average_per_class_accuracy = [acc / len(testloader) for acc  in total_per_class_accuracy]

    for class_idx, acc in enumerate(average_per_class_accuracy):
        print(f'Validation Class {class_idx + 1} Accuracy: {acc * 100:.2f}%')
        

print('Finished Training')
print('Finished Training')


df = pd.DataFrame()

df['loss'] = np.array(losses)
df['val_loss'] = np.array(val_losses)
df.to_csv('results_loss.csv', index=False)



df['recall'] = np.array(recalls)
df['val_recall'] = np.array(val_recalls)

df.to_csv('results.csv', index=False)

torch.save(model.state_dict(),'model_vit1.pth')









  per_class_accuracy = confusion.diagonal() / confusion.sum(axis=1)


[0.7647058823529411, 0.7037037037037037, nan, 0.75]
[1.6737967914438503, 1.1679894179894181, nan]
[2.4998837479655895, 1.7870370370370372, nan]
[3.2035874516692933, 2.523879142300195, nan]
[4.053587451669293, 3.2038791423001953, nan]
[4.664698562780405, 3.912212475633529, nan]
[5.3919712900531325, 4.787212475633529, nan]
[6.044145203096611, 5.576686159844056, nan]
[6.826753898748785, 6.2130497962076925, nan]
[7.660087232082118, 6.832097415255312, nan]
[8.500087232082118, 7.432097415255312, nan]
[9.369652449473422, 8.074954558112454, nan]
[10.202985782806756, 8.759165084428243, nan]
[11.131557211378185, 9.304619629882788, nan]
[12.031557211378185, 9.929619629882788, nan]
[12.81416590703036, 10.694325512235729, nan]
[13.56416590703036, 11.348171666081882, nan]
[14.468927811792264, 11.877583430787764, nan]
[15.37368971655417, 12.43313898634332, nan]
[16.016546859411314, 13.03313898634332, nan]
[16.891546859411314, 13.574805653009985, nan]
[17.786283701516577, 14.003377081581414, nan]
[18.

  per_class_accuracy = confusion.diagonal() / confusion.sum(axis=1)


[22.69786903740532, 18.521202505585325, nan]
[23.38207956372111, 19.229535838918657, nan]
[24.34207956372111, 19.91374636523445, nan]
[25.114806836448384, 20.478963756538796, nan]
[26.01956874121029, 21.248194525769566, nan]
[26.845655697732028, 21.900368438813043, nan]
[27.66047051254684, 22.712868438813043, nan]
[28.478652330728657, 23.331916057860663, nan]
[29.261261026380833, 24.031916057860663, nan]
[30.181261026380835, 24.76875816312382, nan]
[31.02741487253468, 25.340186734552393, nan]
[31.92741487253468, 25.89191087248343, nan]
[32.87478329358731, 26.35344933402189, nan]
[33.724783293587315, 27.090291439285046, nan]
[34.619520135692575, 27.715291439285046, nan]
[35.381424897597334, 28.15973588372949, nan]
[36.048091564264, 28.65973588372949, nan]
[36.890196827421896, 29.114281338274942, nan]
[37.8401968274219, 29.638090862084468, nan]
[38.554482541707614, 30.25713848113209, nan]
[39.31638730361237, 30.85713848113209, nan]
[40.08911457633965, 31.44804757204118, nan]
[40.95275093

  per_class_accuracy = confusion.diagonal() / confusion.sum(axis=1)


[71.07687974062264, 57.26752049322351, nan]
[71.9435464072893, 57.707520493223505, nan]
[72.5344554981984, 58.30752049322351, nan]
[73.41680843937486, 58.8984295841326, nan]
[74.2428953958966, 59.55060349717608, nan]
[74.97973750115976, 60.31983426640684, nan]
[75.61973750115976, 60.95619790277048, nan]
[76.23084861227088, 61.73880659842265, nan]
[77.02251527893755, 62.48880659842265, nan]
[77.88918194560421, 63.29833040794646, nan]
[78.56775337417564, 64.06023516985123, nan]
[79.40108670750897, 64.83296244257849, nan]
[80.24319197066686, 65.42387153348758, nan]
[80.99319197066686, 66.16071363875075, nan]
[81.87554491184332, 66.87499935303646, nan]
[82.70163186836506, 67.44642792446503, nan]
[83.40751422130623, 67.94642792446503, nan]
[84.13828345207546, 68.66071363875074, nan]
[84.83059114438315, 69.18245276918552, nan]
[85.70559114438315, 69.86666329550131, nan]
[86.39524631679694, 70.3428537716918, nan]
[87.39524631679694, 70.9515494238657, nan]
[88.22857965013027, 71.6357599501815,

  per_class_accuracy = confusion.diagonal() / confusion.sum(axis=1)


[108.55410286958838, 86.67211339419039, nan]
[109.26838858387408, 87.35632392050618, nan]
[110.14838858387408, 87.93527128892724, nan]
[110.91761935310485, 88.58233011245665, nan]
[111.75095268643818, 89.12079165091818, nan]
[112.61458905007454, 89.49579165091818, nan]
[113.44792238340787, 90.1148392699658, nan]
[114.24792238340787, 90.88756654269307, nan]


KeyboardInterrupt: 