In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import WeightedRandomSampler
from timm.loss import LabelSmoothingCrossEntropy
from timm.data.mixup import Mixup
import timm
import numpy as np


def get_train_transform():
    return A.Compose([
        A.RandomResizedCrop(224, 224, scale=(0.6, 1.0)),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.3),
        A.RandomRotate90(p=0.3),
        A.ShiftScaleRotate(p=0.3),
        A.CoarseDropout(max_holes=8, max_height=32, max_width=32, p=0.3),  
        A.RandomBrightnessContrast(p=0.4),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])

def get_val_transform():
    return A.Compose([
        A.Resize(224, 224),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])

def create_weighted_sampler(dataset):
    class_counts = [len(v) for v in jpg_paths_comb_dict.values()]
    class_weights = 1. / torch.tensor(class_counts, dtype=torch.float)
    sample_weights = [class_weights[class_id] for _, class_id in dataset]
    return WeightedRandomSampler(sample_weights, class_weights,  num_samples=len(sample_weights))


def create_model(num_classes, device):
    model = timm.create_model('efficientnet_b2', pretrained=True, num_classes=num_classes)
    

    model.classifier = nn.Sequential(
        nn.Linear(model.classifier.in_features, 512),
        nn.SiLU(inplace=True),
        nn.Dropout(0.5),
        nn.Linear(512, num_classes)
    )
    return model.to(device)

def train_model_enhanced(model, train_loader, val_loader, criterion, optimizer, num_epochs=30):
    mixup_fn = Mixup(mixup_alpha=0.4, cutmix_alpha=0.4, prob=0.6)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=3, factor=0.2)
    
    best_f1 = 0.0


    train_losses, train_accuracies = [], []
    val_losses, val_accuracies, val_f1s = [], [], []

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        corrects = 0
        total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            inputs, labels = mixup_fn(inputs, labels)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            corrects += (predicted == labels.argmax(dim=1)).sum().item()

        epoch_loss = running_loss / total
        epoch_acc = corrects / total
        train_losses.append(epoch_loss)
        train_accuracies.append(epoch_acc)

 
        val_loss, val_acc, val_f1 = validate_model(model, val_loader, criterion)
        scheduler.step(val_f1)

        val_losses.append(val_loss)
        val_accuracies.append(val_acc)
        val_f1s.append(val_f1)

   
        print(f"Epoch [{epoch+1}/{num_epochs}]")
        print(f"  Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.4f}")
        print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val F1: {val_f1:.4f}")

        if val_f1 > best_f1:
            best_f1 = val_f1
            torch.save(model.state_dict(), 'best_model.pth')
            print("  🔥 New best model saved!")

    return train_losses, train_accuracies, val_losses, val_accuracies, val_f1s


def validate_model(model, val_loader, criterion):
    model.eval()
    all_preds = []
    all_labels = []
    val_loss = 0.0
    total = 0

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)

            loss = criterion(outputs, labels)
            val_loss += loss.item() * inputs.size(0)
            total += labels.size(0)

            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    cm = confusion_matrix(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average='macro')
    acc = accuracy_score(all_labels, all_preds)
    avg_loss = val_loss / total

    return avg_loss, acc, f1







  from .autonotebook import tqdm as notebook_tqdm


In [None]:

from torch.utils.data import Dataset, DataLoader

class CustomImageDataset(Dataset):
    def __init__(self, jpg_files_dict, transform=None):
        self.transform = transform
        self.image_paths = []
        self.labels = []
        self.label_map = {}
        self._prepare_data(jpg_files_dict)

    def _prepare_data(self, jpg_files_dict):
        for label_idx, (subfolder, file_paths) in enumerate(jpg_files_dict.items()):
            self.label_map[subfolder] = label_idx
            for file_path in file_paths:
                self.image_paths.append(file_path)
                self.labels.append(label_idx)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        try:
            image = Image.open(img_path).convert('RGB')
        except (OSError, IOError) as e:
            print(f'Error loading image {img_path}: {e}')
            image = Image.new('RGB', (224, 224))
            label = self.labels[idx]
        else:
            label = self.labels[idx]

        if self.transform:
            image = np.array(image)
            transformed = self.transform(image=image)
            image = transformed['image']

        return image, label


In [None]:
import os
root_path = '/raid/ee-mariyam/maryam/abhijeet/Combined_Files'
def get_jpg_paths(base_dir):
    jpg_dict = {}
    
    for folder_name in os.listdir(base_dir):
        folder_path = os.path.join(base_dir, folder_name)
        
        if os.path.isdir(folder_path):
            jpg_paths = []
            
            for file_name in os.listdir(folder_path):
                if file_name.lower().endswith('.jpg'):
                    file_path = os.path.join(folder_path, file_name)
                    jpg_paths.append(file_path)
            
            if jpg_paths:
                jpg_dict[folder_name] = jpg_paths
    
    return jpg_dict

combined_files_path = root_path
jpg_paths_dict = get_jpg_paths(combined_files_path)


In [4]:
jpg_paths_comb_dict = jpg_paths_dict.copy()
anthra = jpg_paths_comb_dict['Anthracnose']
twist = jpg_paths_comb_dict['Twister']
antra_twist = anthra+twist
jpg_paths_comb_dict['Antracnose_Twister'] = antra_twist

jpg_paths_comb_dict.pop('Anthracnose', None)
jpg_paths_comb_dict.pop('Twister', None)


['/raid/ee-mariyam/maryam/abhijeet/Combined_Files/Twister/D E6 06.10.2023 DSC_3185.JPG',
 '/raid/ee-mariyam/maryam/abhijeet/Combined_Files/Twister/D E6 06.10.2023 DSC_3244.JPG',
 '/raid/ee-mariyam/maryam/abhijeet/Combined_Files/Twister/D E6 14.11.2023 DSC_6556.JPG',
 '/raid/ee-mariyam/maryam/abhijeet/Combined_Files/Twister/S E6 25.10.2023 IMG_20231025_094620957.jpg',
 '/raid/ee-mariyam/maryam/abhijeet/Combined_Files/Twister/D E6 14.11.2023 DSC_6553.JPG',
 '/raid/ee-mariyam/maryam/abhijeet/Combined_Files/Twister/S E6 25.10.2023 IMG_20231025_094709204.jpg',
 '/raid/ee-mariyam/maryam/abhijeet/Combined_Files/Twister/D E6 06.10.2023 DSC_3270.JPG',
 '/raid/ee-mariyam/maryam/abhijeet/Combined_Files/Twister/D E6 17.11.2023 DSC_6919.JPG',
 '/raid/ee-mariyam/maryam/abhijeet/Combined_Files/Twister/S E6 06.10.2023 IMG_20231006_163321~2.jpg',
 '/raid/ee-mariyam/maryam/abhijeet/Combined_Files/Twister/S E6 07.10.2023 IMG20231007113350.jpg',
 '/raid/ee-mariyam/maryam/abhijeet/Combined_Files/Twister/S 

In [5]:
import os
import torch
import torchvision.transforms as transforms
from torchvision import datasets, models
from torch.utils.data import DataLoader, random_split
import torch.optim as optim
import torch.nn as nn
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import numpy as np
from sklearn.metrics import confusion_matrix
import torch.nn.functional as F 
import cv2
from PIL import Image


In [None]:

if __name__ == "__main__":

    train_transform = get_train_transform()
    val_transform = get_val_transform()
    
    dataset = CustomImageDataset(jpg_paths_comb_dict, transform=train_transform)
    sampler, class_weights = create_weighted_sampler(dataset)
    
    train_loader = DataLoader(dataset, batch_size=64, sampler=sampler, num_workers=16)
    val_loader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=16)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = create_model(len(jpg_paths_comb_dict), device)
    
    criterion = LabelSmoothingCrossEntropy(smoothing=0.1, weight=class_weights.to(device))
 
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

    train_losses, train_accuracies, val_losses, val_accuracies, val_f1s= train_model_enhanced(model, train_loader, val_loader, criterion, optimizer, 40)

Error loading image /raid/ee-mariyam/maryam/abhijeet/Combined_Files/Purple_Blotch/PB_D_E6_18_11_2023_DSC_7118.JPG: broken data stream when reading image file
Error loading image /raid/ee-mariyam/maryam/abhijeet/Combined_Files/Purple_Blotch/PB_D_E6_21_11_2023_DSC_7557.JPG: broken data stream when reading image file
Error loading image /raid/ee-mariyam/maryam/abhijeet/Combined_Files/Purple_Blotch/PB_D_E6_18_11_2023_DSC_7182.JPG: broken data stream when reading image file
Error loading image /raid/ee-mariyam/maryam/abhijeet/Combined_Files/Purple_Blotch/PB_D_E6_21_11_2023_DSC_7631.JPG: broken data stream when reading image file
Error loading image /raid/ee-mariyam/maryam/abhijeet/Combined_Files/Purple_Blotch/PB_D_E6_18_11_2023_DSC_7116.JPG: broken data stream when reading image file
Error loading image /raid/ee-mariyam/maryam/abhijeet/Combined_Files/Purple_Blotch/PB_D_E6_21_11_2023_DSC_7698.JPG: broken data stream when reading image file
Error loading image /raid/ee-mariyam/maryam/abhijeet

In [None]:
import matplotlib.pyplot as plt

def plot_training_metrics(train_losses, val_losses, train_accuracies, val_accuracies, val_f1s):
    epochs = range(1, len(train_losses)+1)

    plt.figure(figsize=(16, 4))


    plt.subplot(1, 3, 1)
    plt.plot(epochs, train_losses, label='Train Loss')
    plt.plot(epochs, val_losses, label='Val Loss')
    plt.title('Loss vs Epoch')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(1, 3, 2)
    plt.plot(epochs, train_accuracies, label='Train Acc')
    plt.plot(epochs, val_accuracies, label='Val Acc')
    plt.title('Accuracy vs Epoch')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()


    plt.subplot(1, 3, 3)
    plt.plot(epochs, val_f1s, label='Val F1 Score', color='purple')
    plt.title('F1 Score vs Epoch')
    plt.xlabel('Epoch')
    plt.ylabel('F1 Score')
    plt.legend()

    plt.tight_layout()
    plt.show()


In [None]:
plot_training_metrics(train_losses, val_losses, train_accuracies, val_accuracies, val_f1s)


In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import seaborn as sns
import numpy as np

def plot_confusion_and_class_accuracy(model, val_loader, class_names):
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    cm = confusion_matrix(all_labels, all_preds)
    class_acc = cm.diagonal() / cm.sum(axis=1)


    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.show()

    plt.figure(figsize=(10, 4))
    sns.barplot(x=class_names, y=class_acc)
    plt.ylim(0, 1)
    plt.title('Class-wise Accuracy')
    plt.xlabel('Class')
    plt.ylabel('Accuracy')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()


In [None]:
plot_confusion_and_class_accuracy(model, val_loader, class_names)
