This notebook performs
- disease classification on the Lettuce NPK dataset
- regression to predict the mean RGB intensities of the same images

In [1]:
import torch

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [None]:
# Use torchvision's ImageFolder dataset: it derives class labels from the sub-folder names, which is convenient for classification
import os
from torchvision import datasets, transforms

# ImageNet channel mean/std — the usual normalization for torchvision's pretrained backbones.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

preprocessing_steps = [
    transforms.Resize((224, 224)),  # fixed input size for the CNN
    transforms.ToTensor(),          # PIL image -> float tensor in [0, 1]
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
]
transform = transforms.Compose(preprocessing_steps)

# One sub-folder per class under ./dataset; labels follow ImageFolder's class_to_idx.
dataset_root = os.path.join(os.getcwd(), "dataset")
dataset = datasets.ImageFolder(root=dataset_root, transform=transform)
print(dataset.class_to_idx)
print(len(dataset))

{'-K': 0, '-N': 1, '-P': 2, 'FN': 3}
208


In [3]:
# Dataloader Part
from torch.utils.data import DataLoader, random_split

# Seed the split so train/val/test contain the same images after every kernel
# restart; an unseeded random_split draws from the global RNG, so each run
# (and each evaluation) would otherwise use a different test set.
SPLIT_SEED = 42
# NOTE: with batch size 1, BatchNorm layers in train mode estimate their
# statistics from a single image, which is noisy; kept as in the original setup.
BATCH_SIZE = 1

train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size  # remainder absorbs rounding

train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(len(train_loader))
print(len(val_loader))
print(len(test_loader))

166
20
22


In [None]:
# Checking train dataset distribution
import collections

# Look up the ImageFolder label of every sample that landed in the training subset.
train_labels = [dataset.targets[sample_idx] for sample_idx in train_dataset.indices]
class_counts = collections.Counter(train_labels)

# Invert class_to_idx once (name -> id becomes id -> name) for the report below.
idx_to_class = {idx: name for name, idx in dataset.class_to_idx.items()}

print("Class distribution in training set:")
for class_id, count in class_counts.items():
    print(f"Class '{idx_to_class[class_id]}' (id={class_id}): {count} samples")

Class distribution in training set:
Class '-N' (id=1): 49 samples
Class '-K' (id=0): 54 samples
Class '-P' (id=2): 54 samples
Class 'FN' (id=3): 9 samples


In [None]:
# Initializing Loss, Optimizer and model
import timm
from torch import nn
from torch.optim import Adam,SGD
from torchvision import models
from torch.optim.lr_scheduler import StepLR  

class DualHeadResNet(nn.Module):
    """Pretrained ResNet-50 feature extractor shared by two heads.

    Args:
        num_classes: number of logits produced by the classification head.

    ``forward(x)`` returns ``(out_cls, out_reg)``: the classification logits
    (``num_classes`` per image) and 3 unbounded regression outputs per image,
    one per RGB channel.
    """

    def __init__(self, num_classes=3):
        super().__init__()
        self.backbone = models.resnet50(weights='ResNet50_Weights.DEFAULT')
        # Read the pooled-feature width from the original fc layer before
        # removing it, rather than hard-coding 2048.
        num_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()  # backbone now returns pooled features
        self.class_head = nn.Linear(num_features, num_classes)

        # No output activation: the regression targets used in this notebook are
        # mean RGB values on the unnormalized 0-255 scale. If the targets are
        # instead scaled to [0, 1], append nn.Sigmoid() after the last layer.
        self.reg_head = nn.Sequential(
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.Linear(512, 3),  # one output per RGB channel
        )

    def forward(self, x):
        """Return (classification logits, RGB regression outputs) for a batch."""
        features = self.backbone(x)
        out_cls = self.class_head(features)
        out_reg = self.reg_head(features)
        return out_cls, out_reg

# The classification head needs one logit per ImageFolder class. A hard-coded 3
# makes CrossEntropyLoss fail on label 3 when the dataset has 4 class folders.
num_classes = len(dataset.classes)
model = DualHeadResNet(num_classes=num_classes).to(device)

criterion_cls = nn.CrossEntropyLoss()
criterion_reg = nn.SmoothL1Loss()
# lr is Adam's default (1e-3), written out so it is visible and easy to tune.
optimizer = Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.99))
# Halve the learning rate once the validation loss stops improving (patience=3 epochs).
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5)
# Alternative optimizer/scheduler pair that can be swapped in:
# optimizer = SGD(model.parameters(), lr=0.0001, momentum=0.9, weight_decay=1e-4)
# scheduler = StepLR(optimizer, step_size=10, gamma=0.1)

# Loss weights. The regression targets are on the 0-255 scale, so the SmoothL1
# term can be far larger than the cross-entropy; lambda_reg down-weights it.
lambda_reg = 0.0001
lambda_cls = 1


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Uncomment the code below to try an EfficientNet backbone (via timm) instead of ResNet-50

# class DualHeadEfficientNet(nn.Module):
#     def __init__(self, backbone_name='efficientnet_b0', num_classes=3):
#         super().__init__()
#         self.backbone = timm.create_model(backbone_name, pretrained=True, num_classes=0)  # remove head
#         num_features = self.backbone.num_features

#         self.class_head = nn.Linear(num_features, num_classes)
#         self.reg_head = nn.Sequential(
#             nn.Linear(num_features, 512),
#             nn.ReLU(),
#             nn.Linear(512, 3)
#         )

#     def forward(self, x):
#         features = self.backbone(x)
#         out_cls = self.class_head(features)
#         out_reg = self.reg_head(features)
#         return out_cls, out_reg

# model = DualHeadEfficientNet(num_classes=3).to(device)

In [None]:
# For plain classification without the regression head, uncomment the lines below
# model = models.resnet50(weights='ResNet50_Weights.DEFAULT')
# model.fc = nn.Linear(model.fc.in_features, 4)

In [None]:
import matplotlib.pyplot as plt
from tqdm import tqdm

# Training loop
num_epochs = 5
best_val_loss = float('inf')
# Portable checkpoint path: a literal r'models\test1.pt' only resolves to the
# models/ folder on Windows; elsewhere it creates a file named "models\test1.pt".
best_checkpoint_path = os.path.join(os.getcwd(), 'models', 'test1.pt')

# Mean/std of the Normalize transform, built once instead of once per batch.
norm_mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
norm_std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)


def rgb_mean_target(image):
    """Return the regression target: per-image mean R, G, B on the 0-255 scale.

    Undoes the Normalize transform, rescales to 0-255 and averages over the two
    spatial dimensions, so a (batch, 3, H, W) input gives a (batch, 3) output.
    """
    unnorm_image = (image * norm_std + norm_mean) * 255.0
    return unnorm_image.mean(dim=[2, 3])


def joint_loss(out_cls, out_reg, label, rgb_gt):
    """Weighted sum of the classification loss and the RGB-regression loss."""
    loss_cls = criterion_cls(out_cls, label)
    loss_reg = criterion_reg(out_reg, rgb_gt)
    return lambda_cls * loss_cls + lambda_reg * loss_reg


train_loss_list = []
val_loss_list = []
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    for image, label in tqdm(train_loader):
        image = image.to(device)
        label = label.to(device)

        out_cls, out_reg = model(image)
        rgb_gt = rgb_mean_target(image)
        # Both heads are trained: the classification loss is computed for this
        # batch inside joint_loss, never taken from a stale variable.
        loss = joint_loss(out_cls, out_reg, label, rgb_gt)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    avg_train_loss = train_loss / len(train_loader)
    train_loss_list.append(avg_train_loss)

    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for image, label in tqdm(val_loader):
            image = image.to(device)
            label = label.to(device)

            out_cls, out_reg = model(image)
            rgb_gt = rgb_mean_target(image)
            val_loss += joint_loss(out_cls, out_reg, label, rgb_gt).item()

    avg_val_loss = val_loss / len(val_loader)
    val_loss_list.append(avg_val_loss)
    # ReduceLROnPlateau monitors the validation loss.
    scheduler.step(avg_val_loss)
    print(f"Epoch [{epoch+1}/{num_epochs}] | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

    if epoch % 5 == 0:
        # Spot-check the regression head on the last validation image.
        print("Pred:", out_reg[0].cpu().detach().numpy())
        print("GT:", rgb_gt[0].cpu().detach().numpy())

    # Keep only the checkpoint with the lowest validation loss seen so far.
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        os.makedirs(os.path.dirname(best_checkpoint_path), exist_ok=True)
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_val_loss,
        }, best_checkpoint_path)
        print(f"Best model saved at epoch {epoch+1} with Val Loss: {avg_val_loss:.4f}")


In [None]:
# Plot the learning curves
# x-values come from the recorded losses rather than num_epochs, so the plot
# still works if training stopped early. Epochs are numbered from 1.
train_epochs = list(range(1, len(train_loss_list) + 1))
val_epochs = list(range(1, len(val_loss_list) + 1))
plt.plot(train_epochs, train_loss_list, c='r', label="Training Curve")
plt.plot(val_epochs, val_loss_list, c='b', label="Validation Curve")
plt.title("Learning Curve")
plt.xlabel("Epoch")
# Ticks start at the first plotted epoch: every epoch for short runs,
# thinned to roughly 10 ticks for long ones.
tick_step = max(1, len(train_epochs) // 10)
plt.xticks(train_epochs[::tick_step])
plt.ylabel("Loss")
plt.legend()
plt.show()

In [None]:
# Generate Predictions
import torch
from torchmetrics.classification import MulticlassConfusionMatrix
from torchmetrics.classification import Precision, Recall, F1Score
from torchmetrics.classification import Accuracy
from torchmetrics.regression import MeanSquaredError, MeanAbsoluteError, R2Score


# Load the model checkpoint.
# NOTE(review): this is not the file the training loop writes (models/test1.pt);
# point it there to evaluate the model trained in this session. The checkpoint's
# class count must match `model`, or load_state_dict raises a size mismatch.
eval_checkpoint_path = os.path.join(os.getcwd(), 'models', '40_images_cls_vl_5_epochs.pt')
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
checkpoint = torch.load(eval_checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)  # Make sure the model is on the evaluation device (CPU/GPU)


# Classification metrics: one class per ImageFolder sub-folder, so the metric
# sizes match the test labels (a hard-coded 3 rejects label 3 in a 4-class dataset).
num_classes = len(dataset.classes)
confusion_matrix = MulticlassConfusionMatrix(num_classes=num_classes).to(device)
precision = Precision(task="multiclass", num_classes=num_classes, average='macro').to(device)
recall = Recall(task="multiclass", num_classes=num_classes, average='macro').to(device)
f1_score = F1Score(task="multiclass", num_classes=num_classes, average='macro').to(device)
accuracy = Accuracy(task="multiclass", num_classes=num_classes).to(device)
# Regression metrics (for RGB intensity regression)
mse = MeanSquaredError().to(device)
mae = MeanAbsoluteError().to(device)
r2 = R2Score().to(device)
classification_metrics = (confusion_matrix, precision, recall, f1_score, accuracy)
regression_metrics = (mse, mae, r2)

# Mean/std of the Normalize transform, built once for computing RGB targets.
norm_mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
norm_std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)

# Per-batch labels, predictions and regression outputs, kept for later inspection.
all_labels = []
all_preds = []
all_out_reg = []
all_rgb_gt = []

model.eval()
with torch.no_grad():
    for image, label in test_loader:
        image = image.to(device)
        label = label.to(device)

        out_cls, out_reg = model(image)
        _, preds = torch.max(out_cls, 1)  # predicted class index per image

        # Regression target: per-image mean R, G, B after undoing Normalize, on 0-255.
        rgb_gt = ((image * norm_std + norm_mean) * 255.0).mean(dim=[2, 3])

        all_labels.append(label)
        all_preds.append(preds)
        all_out_reg.append(out_reg)
        all_rgb_gt.append(rgb_gt)

        for metric in classification_metrics:
            metric.update(preds, label)
        for metric in regression_metrics:
            metric.update(out_reg, rgb_gt)

# Compute final metrics
confusion_matrix_result = confusion_matrix.compute()
precision_result = precision.compute()
recall_result = recall.compute()
f1_score_result = f1_score.compute()
accuracy_result = accuracy.compute()
# Compute final Regression Metrics
mse_result = mse.compute()
mae_result = mae.compute()
r2_result = r2.compute()


print("Confusion Matrix:")
print(confusion_matrix_result)
print(f"Precision: {precision_result:.4f}")
print(f"Recall: {recall_result:.4f}")
print(f"F1 Score: {f1_score_result:.4f}")
print(f"Accuracy: {accuracy_result:.4f}")

print(f"Mean Squared Error (MSE): {mse_result:.4f}")
print(f"Mean Absolute Error (MAE): {mae_result:.4f}")
print(f"R² Score: {r2_result:.4f}")



In [None]:
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from torchmetrics.classification import MulticlassConfusionMatrix


# Convert to numpy
confmat_np = confusion_matrix_result.cpu().numpy()
# Row/column i of the confusion matrix is class index i, and ImageFolder's
# `classes` list is ordered by that index (see class_to_idx). Taking the tick
# labels from it keeps them correct when the class folders change; a
# hand-written list silently mislabels the rows.
classes = dataset.classes
assert len(classes) == confmat_np.shape[0], (
    f"{len(classes)} dataset classes vs a {confmat_np.shape[0]}-class confusion matrix; "
    "re-run the evaluation with num_classes matching the dataset"
)
# Plot the heatmap
plt.figure(figsize=(6, 5))
sns.heatmap(confmat_np, annot=True, fmt="d", cmap="Blues", cbar=False,
            xticklabels=[f"Pred {name}" for name in classes],
            yticklabels=[f"True {name}" for name in classes])
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.title("Confusion Matrix Heatmap")
plt.tight_layout()
plt.show()
