In [None]:
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
## Testing the model by selecting a random image from the image folders
import random
from PIL import Image
import os
import timm
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
def _coral_covariance(features):
    """Return the (d, d) sample covariance of a (n, d) feature batch.

    Uses the unbiased n - 1 denominator, clamped to 1 so that a single-sample
    batch (e.g. a ragged last batch) yields a zero matrix instead of 0/0 = NaN.
    """
    centered = features - features.mean(dim=0, keepdim=True)
    return centered.T @ centered / max(features.size(0) - 1, 1)


def coral_loss(source, target):
    """Deep CORAL loss to align feature distributions.

    Args:
        source: 2-D tensor (n_s, d) of source-domain features.
        target: 2-D tensor (n_t, d) of target-domain features; n_t may differ
            from n_s, but d must match.

    Returns:
        Scalar tensor ||C_s - C_t||_F^2 / (4 * d^2), where C_s and C_t are the
        feature covariance matrices of the two batches.

    Raises:
        ValueError: if either input is not 2-D or the feature dimensions differ.
    """
    if source.dim() != 2 or target.dim() != 2:
        raise ValueError(
            f"coral_loss expects 2-D (n, d) tensors, got shapes "
            f"{tuple(source.shape)} and {tuple(target.shape)}"
        )
    if source.size(1) != target.size(1):
        raise ValueError(
            f"feature dimensions differ: {source.size(1)} vs {target.size(1)}"
        )
    d = source.size(1)  # Feature dimension
    diff = _coral_covariance(source) - _coral_covariance(target)
    # Sum of squares equals the squared Frobenius norm, without taking a sqrt
    # (whose gradient is undefined at 0) only to square it again.
    return diff.pow(2).sum() / (4 * d ** 2)

def new_domain_accuracy(model, new_domain_loader, device):
    """Print and return the classification accuracy (in %) of ``model`` on a labelled loader.

    The model is switched to eval mode for the pass, so BatchNorm running
    statistics are not updated with evaluation data and dropout is disabled.
    Its previous train/eval mode is restored afterwards, even on error.

    Args:
        model: classifier returning per-class scores of shape (batch, classes).
        new_domain_loader: iterable of ``(images, labels)`` batches.
        device: device the batches are moved to before the forward pass.

    Returns:
        Accuracy in percent as a float, or NaN if the loader yields no samples.
    """
    was_training = model.training
    model.eval()
    num_wrong = 0
    num_seen = 0
    try:
        with torch.no_grad():
            for images, labels in new_domain_loader:
                images, labels = images.to(device), labels.to(device)
                predicted = model(images).argmax(dim=1)
                num_wrong += (predicted != labels).sum().item()
                # Count what was actually iterated (robust to drop_last / custom samplers)
                num_seen += labels.size(0)
    finally:
        model.train(was_training)

    if num_seen == 0:
        print("Accuracy for new domain: n/a (empty dataset)")
        return float("nan")
    accuracy = (1 - num_wrong / num_seen) * 100
    print("Accuracy for new domain: " + str(accuracy))
    return accuracy


def train_domain_adaptation(model, source_loader, target_loader, criterion, optimizer,
                            lambda_coral=0.1, epochs=20, device=None, eval_loader=None):
    """Train with a supervised source loss plus a Deep CORAL term on target features.

    Each step draws one batch from each loader in lockstep (``zip``), so an
    epoch lasts as many steps as the *shorter* loader has batches.

    Args:
        model: timm-style classifier exposing ``forward_features(x)`` and
            ``forward_head(x, pre_logits=...)``.
        source_loader: yields ``(images, labels)`` from the labelled source domain.
        target_loader: yields ``(images, _)`` from the target domain; labels ignored.
        criterion: classification loss applied to source logits and labels.
        optimizer: optimizer over ``model``'s parameters.
        lambda_coral: weight of the CORAL term in the total loss.
        epochs: number of passes over the zipped loaders.
        device: where batches are moved; defaults to the device of the
            model's first parameter.
        eval_loader: labelled target-domain loader scored with
            ``new_domain_accuracy`` after every epoch. If omitted, a
            module-level ``test_target_loader`` is used when one exists
            (the notebook's original behaviour); otherwise evaluation is skipped.

    Raises:
        ValueError: if the zipped loaders yield no batches in an epoch.
    """
    if device is None:
        device = next(model.parameters()).device
    if eval_loader is None:
        eval_loader = globals().get("test_target_loader")

    for epoch in range(epochs):
        # Re-enter training mode every epoch: the evaluation below may change it.
        model.train()
        total_loss = 0.0
        total_coral = 0.0
        num_batches = 0
        correct_predictions = 0
        total_samples = 0

        for (source_images, source_labels), (target_images, _) in zip(source_loader, target_loader):
            source_images, source_labels = source_images.to(device), source_labels.to(device)
            target_images = target_images.to(device)

            optimizer.zero_grad()

            # One backbone pass per domain; source logits reuse the source feature maps.
            source_maps = model.forward_features(source_images)
            target_maps = model.forward_features(target_images)
            source_outputs = model.forward_head(source_maps)
            classification_loss = criterion(source_outputs, source_labels)

            # Align pooled pre-logits features rather than flattened spatial maps:
            # [32, 512, 7, 7] would flatten to 25088 dims, i.e. a 25088 x 25088
            # covariance matrix per domain per step.
            source_features = model.forward_head(source_maps, pre_logits=True).flatten(1)
            target_features = model.forward_head(target_maps, pre_logits=True).flatten(1)
            coral_loss_val = coral_loss(source_features, target_features)

            # Total loss = classification loss + CORAL loss
            loss = classification_loss + lambda_coral * coral_loss_val
            loss.backward()
            optimizer.step()

            num_batches += 1
            total_loss += loss.item()
            total_coral += coral_loss_val.item()

            # Accuracy on the source batch (labels are only available there)
            predicted = source_outputs.argmax(dim=1)
            total_samples += source_labels.size(0)
            correct_predictions += (predicted == source_labels).sum().item()

        if num_batches == 0:
            raise ValueError("source_loader and target_loader must each yield at least one batch")

        # Average over steps actually taken (zip stops at the shorter loader).
        avg_loss = total_loss / num_batches
        # Epoch mean of the CORAL term, not just the last batch's value.
        avg_coral = total_coral / num_batches
        accuracy = correct_predictions / total_samples * 100
        if eval_loader is not None:
            new_domain_accuracy(model, eval_loader, device)
        print(f"TRAINING: Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%, CORAL Loss: {avg_coral:.4f}")


# Parameters
data_dir = '/content/DAPlankton_subset/CS'  # Source domain (labelled)
train_size = 0.4  # Fraction of each domain used for training; the rest is held out
seed = 0  # Fixes the train/test splits so runs are reproducible
num_classes = 3


class UnlabeledDataset(torch.utils.data.Dataset):
    """Expose a labelled dataset's samples with every label replaced by -1.

    Samples are fetched lazily, so the random augmentation in the wrapped
    dataset's transform is re-drawn every epoch instead of being frozen by a
    one-off list build (which would also hold every image in memory).
    """

    def __init__(self, base):
        self.base = base

    def __len__(self):
        return len(self.base)

    def __getitem__(self, index):
        image, _ = self.base[index]
        return image, -1


# Training transformations: domain harmonisation plus random augmentation
transform = transforms.Compose([
    transforms.Resize((224, 224)),              # Normalize image size
    transforms.RandomHorizontalFlip(),          # Add augmentation for robustness
    transforms.RandomAffine(degrees=15,         # Slight rotation to reduce sensitivity
                            scale=(0.8, 1.2)),  # Random zoom-in/out to balance scale difference
    transforms.ColorJitter(brightness=0.2,      # Adjust brightness to balance intensity differences
                           contrast=0.2,
                           saturation=0.2),
    transforms.Grayscale(num_output_channels=3),# Convert all to 3-channel grayscale for consistency
    transforms.GaussianBlur(kernel_size=3),     # Reduce sharpness differences (random sigma)
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],  # Normalize pixel intensities
                         std=[0.5, 0.5, 0.5]),
])

# Evaluation transformations: the same harmonisation steps but deterministic,
# so held-out accuracy is not perturbed by random flips/affines/jitter/blur.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=3),
    transforms.GaussianBlur(kernel_size=3, sigma=1.0),  # Fixed sigma instead of a random draw
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],
                         std=[0.5, 0.5, 0.5]),
])

split_generator = torch.Generator().manual_seed(seed)

# Source domain. Two ImageFolder views of the same directory (one per transform)
# list the same files in the same order, so the split indices can be shared.
dataset = datasets.ImageFolder(root=data_dir, transform=transform)
eval_dataset = datasets.ImageFolder(root=data_dir, transform=eval_transform)
train_len = int(len(dataset) * train_size)
test_len = len(dataset) - train_len
train_dataset, test_split = random_split(dataset, [train_len, test_len], generator=split_generator)
test_dataset = torch.utils.data.Subset(eval_dataset, test_split.indices)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Load pre-trained ResNet-18 with a classifier head sized for num_classes
model = timm.create_model('resnet18', pretrained=True, num_classes=num_classes)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Target domain: its labels are used only for held-out evaluation
target_data_dir = '/content/DAPlankton_subset/FC'  # Change this to your target domain folder
target_dataset = datasets.ImageFolder(root=target_data_dir, transform=transform)
target_eval_dataset = datasets.ImageFolder(root=target_data_dir, transform=eval_transform)
train_len = int(len(target_dataset) * train_size)
test_len = len(target_dataset) - train_len
train_target_dataset, test_target_split = random_split(
    target_dataset, [train_len, test_len], generator=split_generator
)
test_target_dataset = torch.utils.data.Subset(target_eval_dataset, test_target_split.indices)

# Adaptation is unsupervised: hide the labels of the *target* training split.
# Wrapping the source split here would make CORAL align the source with itself.
train_nolabel_dataset = UnlabeledDataset(train_target_dataset)

train_target_loader = DataLoader(train_nolabel_dataset, batch_size=32, shuffle=True)
test_target_loader = DataLoader(test_target_dataset, batch_size=32, shuffle=False)

train_domain_adaptation(model, train_loader, train_target_loader, criterion, optimizer)


NameError: name 'transforms' is not defined