In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import pandas as pd
import os
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from sklearn.metrics import precision_recall_curve, accuracy_score, precision_score, recall_score, accuracy_score,f1_score, confusion_matrix, roc_auc_score
import matplotlib.pyplot as plt
from tqdm import tqdm
import seaborn as sns

In [None]:
import os
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset
import torch

class MultimodalDataset(Dataset):
    """Dataset yielding the six modality images of one sample together with its label.

    Each row of the CSV names one sample; the same file name is looked up in each of the
    six modality directories. ``__getitem__`` returns
    ``((white, uv365, uv395, white_gamma, nouv365, nouv395), label)`` with ``label`` as a
    float32 scalar tensor.

    When ``transform`` is random (e.g. contains ``RandomHorizontalFlip``), the same random
    draws are replayed for all six images of a sample, so every modality of one sample
    receives identical augmentation instead of six independent ones.
    """

    # Directory attributes in the order the images are loaded and returned; the training
    # and evaluation loops unpack the image tuple in exactly this order.
    _MODALITY_DIR_ATTRS = (
        'white_dir',
        'uv365_dir',
        'uv395_dir',
        'white_Gamma_dir',
        'Nouv365_dir',
        'Nouv395_dir',
    )

    def __init__(self, csv_file, white_Gamma_dir, Nouv365_dir, Nouv395_dir, white_dir, uv365_dir, uv395_dir, transform=None):
        """
        Args:
            csv_file (str): Path to CSV with columns: [filename, label]
            white_Gamma_dir (str): Directory for white images with Gamma correction
            Nouv365_dir (str): Directory for new 365nm UV images
            Nouv395_dir (str): Directory for new 395nm UV images
            white_dir (str): Directory for standard white images
            uv365_dir (str): Directory for 365nm UV images
            uv395_dir (str): Directory for 395nm UV images
            transform (callable, optional): Transformations to apply to images
        """
        self.data = pd.read_csv(csv_file)
        self.white_Gamma_dir = white_Gamma_dir
        self.Nouv365_dir = Nouv365_dir
        self.Nouv395_dir = Nouv395_dir
        self.white_dir = white_dir
        self.uv365_dir = uv365_dir
        self.uv395_dir = uv395_dir
        self.transform = transform

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _load_rgb(path):
        """Open the image at ``path``, convert it to RGB and release the file handle."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def _transform_consistently(self, images):
        """Apply ``self.transform`` to each image using the same random draws for all.

        Both Python's ``random`` and torch's CPU RNG state are captured once and restored
        before each call, so any random transform samples identical parameters for every
        modality of the sample.
        """
        import random

        py_state = random.getstate()
        torch_state = torch.get_rng_state()
        transformed = []
        for img in images:
            random.setstate(py_state)
            torch.set_rng_state(torch_state)
            transformed.append(self.transform(img))
        return transformed

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        # Filenames are stored as numbers in the CSV (possibly parsed as floats, e.g. 12.0).
        img_name = str(int(row['filename'])) + '.jpg'
        label = row['label']

        images = [
            self._load_rgb(os.path.join(getattr(self, attr), img_name))
            for attr in self._MODALITY_DIR_ATTRS
        ]

        if self.transform:
            images = self._transform_consistently(images)

        return tuple(images), torch.tensor(label, dtype=torch.float32)


In [None]:


class MultimodalModelWithConvFusion(nn.Module):
    """Late-fusion classifier over six image modalities.

    Every modality has its own ResNet-18 (``pretrained=True``) whose final ``fc`` layer is
    replaced by ``nn.Identity`` so it emits a 512-dim feature vector. The six vectors are
    concatenated (3072 channels), reduced back to 512 channels by a 1x1 convolution, and
    fed to a two-layer MLP head.

    Args:
        num_classes: width of the output layer (1 for the single-logit setup used with
            ``BCEWithLogitsLoss`` in this notebook).
    """

    # Backbone attribute names, in the same order as forward()'s positional inputs.
    # They are also the state_dict key prefixes, so saved checkpoints depend on them.
    _BACKBONE_NAMES = (
        'white_backbone',
        'uv365_backbone',
        'uv395_backbone',
        'white_gamma_backbone',
        'nouv365_backbone',
        'nouv395_backbone',
    )

    def __init__(self, num_classes=1):
        super().__init__()

        for backbone_name in self._BACKBONE_NAMES:
            backbone = models.resnet18(pretrained=True)
            backbone.fc = nn.Identity()  # keep the pooled features, drop the classification head
            setattr(self, backbone_name, backbone)

        feature_dim = 512
        # Mixes the concatenated 6 * 512 = 3072 channels back down to 512.
        self.conv1x1 = nn.Conv2d(
            in_channels=feature_dim * len(self._BACKBONE_NAMES),
            out_channels=feature_dim,
            kernel_size=1,
        )

        self.fc = nn.Sequential(
            nn.Linear(feature_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        )

    def forward(self, x_white, x_uv365, x_uv395, x_white_gamma, x_nouv365, x_nouv395):
        """Return ``(out, embedding)``: head output (batch, num_classes) and fused embedding (batch, 512)."""
        inputs = (x_white, x_uv365, x_uv395, x_white_gamma, x_nouv365, x_nouv395)
        per_modality = [
            getattr(self, name)(x) for name, x in zip(self._BACKBONE_NAMES, inputs)
        ]

        # (batch, 3072) -> (batch, 3072, 1, 1) so the 1x1 conv can consume it.
        stacked = torch.cat(per_modality, dim=1)[:, :, None, None]
        embedding = self.conv1x1(stacked).flatten(start_dim=1)

        return self.fc(embedding), embedding


In [None]:
# Image preprocessing applied to every modality.
# NOTE(review): this pipeline contains random augmentation (horizontal flip); it should
# not be applied as-is to validation/test images.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),      # fixed input size for the ResNet-18 backbones
        transforms.RandomHorizontalFlip(),  # data augmentation
        transforms.ToTensor(),              # PIL image -> tensor
        transforms.Normalize(               # ImageNet statistics, matching the pretrained weights
            [0.485, 0.456, 0.406],
            [0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:
# Dataset root. Set the LEAF_DATASET_ROOT environment variable to point at another copy
# instead of editing this machine-specific default.
root = os.environ.get("LEAF_DATASET_ROOT", "/home/stud1/Desktop/PIL_MAIN/Leaf Dataset")
csv_file = os.path.join(root, "labels.csv")

# One directory per imaging modality; the strings are the literal on-disk folder names.
# NOTE(review): white_Gamma_dir points at "WhiteUV" although the dataset docstring calls it
# "white images with Gamma correction" — confirm this mapping.
white_Gamma_dir = os.path.join(root, "WhiteUV")
Nouv365_dir = os.path.join(root, "365NoUV")
Nouv395_dir = os.path.join(root, "395NoUV")
white_dir = os.path.join(root, "WhiteNoUV")
uv365_dir = os.path.join(root, "365UV")
uv395_dir = os.path.join(root, "395UV")

In [None]:
# Build the dataset here so this cell runs on a fresh kernel (`dataset` must exist
# before it is split).
dataset_dirs = dict(
    white_Gamma_dir=white_Gamma_dir,
    Nouv365_dir=Nouv365_dir,
    Nouv395_dir=Nouv395_dir,
    white_dir=white_dir,
    uv365_dir=uv365_dir,
    uv395_dir=uv395_dir,
)
dataset = MultimodalDataset(csv_file, transform=transform, **dataset_dirs)

# Validation/test must be deterministic: use a second view of the same files whose
# transform is the training pipeline minus the random horizontal flip.
eval_transform = transforms.Compose(
    [t for t in transform.transforms if not isinstance(t, transforms.RandomHorizontalFlip)]
)
eval_dataset = MultimodalDataset(csv_file, transform=eval_transform, **dataset_dirs)

# Define split sizes (70% train / 10% val / remainder test)
total_size = len(dataset)
train_size = int(0.70 * total_size)
val_size = int(0.10 * total_size)
test_size = total_size - train_size - val_size

seed = 42

# Split dataset (seeded, so the partition is reproducible)
train_dataset, val_subset, test_subset = torch.utils.data.random_split(
    dataset, [train_size, val_size, test_size], generator=torch.Generator().manual_seed(seed)
)
# Same sample indices as the split above, but served through the non-augmenting transform.
val_dataset = torch.utils.data.Subset(eval_dataset, val_subset.indices)
test_dataset = torch.utils.data.Subset(eval_dataset, test_subset.indices)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=4)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG before building the model so the randomly initialised layers
# (conv1x1, fc head) and the DataLoader shuffling order are reproducible across runs
# (up to nondeterministic GPU kernels). `seed` is defined in the data-split cell.
torch.manual_seed(seed)

model = MultimodalModelWithConvFusion().to(device)
criterion = nn.BCEWithLogitsLoss()  # the model returns raw logits; the loss applies the sigmoid
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-6)

In [None]:
# Report how many samples ended up in each split.
for split_name, split in (("Train", train_dataset), ("Validation", val_dataset), ("Test", test_dataset)):
    print(f"{split_name} samples: {len(split)}")

In [None]:
from tqdm import tqdm
import torch

def train_one_epoch_with_metrics(model, dataloader, criterion, optimizer, device, epoch, num_epochs):
    """Run one training epoch and return ``(mean_loss, accuracy)`` over all samples seen.

    Args:
        model: module called as ``model(*images)`` and returning ``(logits, embedding)``.
        dataloader: yields ``(images, labels)`` where ``images`` is the tuple of per-modality
            batches produced by ``MultimodalDataset`` and ``labels`` has shape (batch,).
        criterion: loss on ``(logits, labels[:, None])``, e.g. ``BCEWithLogitsLoss``.
        optimizer: optimizer stepping ``model``'s parameters.
        device: device the batches are moved to.
        epoch, num_epochs: only used for the progress-bar description.

    Returns:
        ``(mean_loss, accuracy)``; both are ``nan`` if the dataloader yields no samples
        (instead of raising ZeroDivisionError).
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    train_progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)

    for images, labels in train_progress_bar:
        # Move every modality batch and the labels to the target device.
        images = [img.to(device) for img in images]
        labels = labels.to(device).unsqueeze(1)  # (batch,) -> (batch, 1) to match the logits

        optimizer.zero_grad()
        outputs, _ = model(*images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # criterion returns a batch mean; weight by batch size for an exact epoch mean.
        running_loss += loss.item() * batch_size

        preds = torch.sigmoid(outputs) > 0.5
        correct += (preds == labels.byte()).sum().item()
        total += batch_size

        train_progress_bar.set_postfix(loss=running_loss / total, accuracy=correct / total)

    if total == 0:
        return float('nan'), float('nan')
    return running_loss / total, correct / total

In [None]:
from tqdm import tqdm
import torch

def validate_with_metrics(model, dataloader, criterion, device, epoch=None, num_epochs=None):
    """Evaluate ``model`` without gradients and return ``(mean_loss, accuracy)``.

    Args:
        model: module called as ``model(*images)`` and returning ``(logits, embedding)``.
        dataloader: yields ``(images, labels)`` where ``images`` is the tuple of per-modality
            batches produced by ``MultimodalDataset`` and ``labels`` has shape (batch,).
        criterion: loss on ``(logits, labels[:, None])``.
        device: device the batches are moved to.
        epoch, num_epochs: optional; when both are given they are shown in the progress bar.

    Returns:
        ``(mean_loss, accuracy)``; both are ``nan`` if the dataloader yields no samples
        (instead of raising ZeroDivisionError).
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    desc = "Validation"
    if epoch is not None and num_epochs is not None:
        desc = f"Epoch {epoch+1}/{num_epochs} Validation"

    with torch.no_grad():
        val_progress_bar = tqdm(dataloader, desc=desc, leave=False)

        for images, labels in val_progress_bar:
            images = [img.to(device) for img in images]
            labels = labels.to(device).unsqueeze(1)  # (batch,) -> (batch, 1) to match the logits

            outputs, _ = model(*images)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            running_loss += loss.item() * batch_size
            preds = torch.sigmoid(outputs) > 0.5
            correct += (preds == labels.byte()).sum().item()
            total += batch_size

            val_progress_bar.set_postfix(loss=running_loss / total, accuracy=correct / total)

    if total == 0:
        return float('nan'), float('nan')
    return running_loss / total, correct / total

In [None]:
from tqdm import tqdm
import torch
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_auc_score

def test_model(model, dataloader, criterion, device, threshold=0.5):
    """Evaluate ``model`` on ``dataloader``; print and return binary-classification metrics.

    Args:
        model: module called as ``model(*images)`` and returning ``(logits, embedding)``.
        dataloader: yields ``(images, labels)`` with ``labels`` of shape (batch,) holding 0/1.
        criterion: loss on ``(logits, labels[:, None])``.
        device: device the batches are moved to.
        threshold: probability cut-off for predicting the positive class (default 0.5,
            the same cut-off used for training/validation accuracy).

    Returns:
        ``(avg_loss, accuracy, precision, recall, f1, cm, roc_auc)``. ``cm`` is always 2x2
        (rows: true 0/1, columns: predicted 0/1). ``roc_auc`` is computed from the predicted
        probabilities and is ``nan`` when the labels contain only one class.

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    model.eval()
    running_loss = 0.0
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for images, labels in tqdm(dataloader):
            images = [img.to(device) for img in images]
            labels = labels.to(device).unsqueeze(1)

            outputs, _ = model(*images)
            loss = criterion(outputs, labels)
            running_loss += loss.item() * labels.size(0)

            # Flatten (batch, 1) -> plain Python lists so sklearn sees 1-D targets.
            all_labels.extend(labels.view(-1).long().cpu().tolist())
            all_probs.extend(torch.sigmoid(outputs).view(-1).cpu().tolist())

    if not all_labels:
        raise ValueError("test_model: dataloader yielded no samples")

    all_preds = [int(p > threshold) for p in all_probs]

    # Divide by the number of samples actually evaluated (correct even with drop_last).
    avg_loss = running_loss / len(all_labels)
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds)
    recall = recall_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds)
    # Fix the label set so the matrix stays 2x2 even if a class is absent from this split.
    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])
    # ROC AUC needs continuous scores; thresholded 0/1 predictions collapse the curve to a
    # single operating point. It is undefined when only one class is present.
    if len(set(all_labels)) > 1:
        roc_auc = roc_auc_score(all_labels, all_probs)
    else:
        roc_auc = float('nan')

    print(f"Test Loss: {avg_loss:.4f}")
    print(f"Test Accuracy: {accuracy:.4f}")
    print(f"Test Precision: {precision:.4f}")
    print(f"Test Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")
    print(f"Confusion Matrix:\n{cm}")
    print(f"ROC AUC: {roc_auc:.4f}")

    return avg_loss, accuracy, precision, recall, f1, cm, roc_auc

In [None]:
# Checkpoint bookkeeping: any finite validation loss improves on +inf.
best_model_path = "./best_model.pth"
best_val_loss = float('inf')

In [None]:
num_epochs = 11
best_val_loss = float('inf')
best_model_path = "best_model.pth"

# Per-epoch history, consumed by the plotting cell.
train_losses, train_accuracies = [], []
val_losses, val_accuracies = [], []

for epoch in range(num_epochs):
    train_loss, train_acc = train_one_epoch_with_metrics(
        model, train_loader, criterion, optimizer, device, epoch, num_epochs
    )
    train_losses.append(train_loss)
    train_accuracies.append(train_acc)

    val_loss, val_acc = validate_with_metrics(
        model, val_loader, criterion, device, epoch, num_epochs
    )
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)

    # Keep only the weights with the lowest validation loss seen so far.
    if not (val_loss >= best_val_loss) and val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), best_model_path)

    print(
        f"Epoch {epoch+1}/{num_epochs}",
        f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}",
        f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}",
        sep="\n",
    )

torch.cuda.empty_cache()

In [None]:
# Evaluate the checkpoint with the lowest validation loss (written by the training loop),
# not whatever weights the final epoch happened to leave in `model`.
# torch.load unpickles the file: only load checkpoints this notebook produced itself.
model.load_state_dict(torch.load(best_model_path, map_location=device))
test_loss, test_acc, test_prec, test_rec, f1, cm, roc_auc = test_model(model, test_loader, criterion, device)
torch.cuda.empty_cache()


In [None]:

def _plot_history(train_values, val_values, metric, filename):
    """Plot per-epoch train/validation curves for one metric and save the figure.

    The x-axis comes from the recorded history length (not `num_epochs`), so the plot
    still works if training stopped early.
    """
    fig, ax = plt.subplots(figsize=(12, 5))
    ax.plot(range(1, len(train_values) + 1), train_values, label=f"Train {metric}")
    ax.plot(range(1, len(val_values) + 1), val_values, label=f"Validation {metric}")
    ax.set_xlabel("Epochs")
    ax.set_ylabel(metric)
    ax.set_title(f"Epoch vs {metric} for all modalities")
    ax.legend()
    ax.grid()
    fig.savefig(filename)
    return fig


# Save loss and accuracy curves
_plot_history(train_losses, val_losses, "Loss", "./loss_all_curve.png")
_plot_history(train_accuracies, val_accuracies, "Accuracy", "./accuracy_all_curve.png")

# NOTE(review): assumes label 0 = "Healthy" and label 1 = "Unhealthy" — confirm against labels.csv.
class_names = ["Healthy", "Unhealthy"]

# Confusion-matrix heatmap, saved as an image
cm_fig, cm_ax = plt.subplots(figsize=(6, 6))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=class_names, yticklabels=class_names, ax=cm_ax)
cm_ax.set_xlabel("Predicted Labels")
cm_ax.set_ylabel("True Labels")
cm_ax.set_title("Confusion Matrix \n all modalities")
cm_fig.savefig("confusion_all_matrix.png")
plt.show()
