In [None]:
!pip install torch torchvision tqdm pillow

In [None]:
import torch
import random
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
from PIL import Image
from tqdm import tqdm

class FeatureExtractor(nn.Module):
    """ResNet-50 backbone that returns a spatial feature map.

    The last two child modules of the torchvision ResNet-50 are dropped so the
    output stays a feature map instead of class logits (presumably
    2048 x 7 x 7 for a 224 x 224 input, per the note in ``Discriminator``).

    Args:
        pretrained: When True, load ImageNet weights; otherwise start from
            random initialisation.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        # `pretrained=` is deprecated in recent torchvision releases in favour
        # of the `weights=` enum; keep the old call as a fallback so the class
        # still works on installations that predate the enum.
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is not None:
            weights = weights_enum.IMAGENET1K_V1 if pretrained else None
            resnet = models.resnet50(weights=weights)
        else:
            resnet = models.resnet50(pretrained=pretrained)
        self.features = nn.Sequential(*list(resnet.children())[:-2])

    def forward(self, x):
        """Return the backbone feature map for the image batch ``x``."""
        return self.features(x)

class AttentionModule(nn.Module):
    """Element-wise gate built from a 1x1-conv bottleneck.

    The input is squeezed to ``in_channels // 8`` channels, expanded back to
    ``in_channels``, passed through a sigmoid and multiplied onto the input.

    Args:
        in_channels: Channel count of the incoming feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        bottleneck_channels = in_channels // 8
        self.conv1 = nn.Conv2d(in_channels, bottleneck_channels, kernel_size=1)
        self.conv2 = nn.Conv2d(bottleneck_channels, in_channels, kernel_size=1)

    def forward(self, x):
        """Return ``x`` scaled by its learned gate (same shape as ``x``)."""
        squeezed = F.relu(self.conv1(x))
        gate = torch.sigmoid(self.conv2(squeezed))
        gated = x * gate
        return gated

class Discriminator(nn.Module):
    """Convolutional discriminator that scores a feature map in [0, 1].

    Three 3x3 conv + LeakyReLU(0.2) + BatchNorm stages keep the spatial size,
    then the map is flattened into a single linear unit followed by a sigmoid.
    The linear layer is sized for a 7 x 7 grid; feature maps of any other
    spatial size are adaptively average-pooled to 7 x 7 first, so inputs other
    than 224 x 224 images no longer crash at the flatten step.

    Args:
        input_channels: Channel count of the incoming feature map.
    """

    # Spatial grid the final linear layer was sized for.
    _FC_GRID = (7, 7)

    def __init__(self, input_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, 512, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
        self.fc = nn.Linear(128 * self._FC_GRID[0] * self._FC_GRID[1], 1)

        # Batch normalization for better training stability
        self.bn1 = nn.BatchNorm2d(512)
        self.bn2 = nn.BatchNorm2d(256)
        self.bn3 = nn.BatchNorm2d(128)

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of sigmoid scores.

        Raises:
            ValueError: If ``x`` is not a 4-D ``(batch, channels, H, W)`` tensor.
        """
        if x.dim() != 4:
            raise ValueError(
                f"Discriminator expects a 4-D (N, C, H, W) tensor, got shape {tuple(x.shape)}"
            )
        x = self.bn1(F.leaky_relu(self.conv1(x), 0.2))
        x = self.bn2(F.leaky_relu(self.conv2(x), 0.2))
        x = self.bn3(F.leaky_relu(self.conv3(x), 0.2))
        # Pooling to the grid the linear layer expects is the identity for
        # 7 x 7 maps, so it is skipped there to keep the original numerics.
        if tuple(x.shape[-2:]) != self._FC_GRID:
            x = F.adaptive_avg_pool2d(x, self._FC_GRID)
        x = torch.flatten(x, 1)
        return torch.sigmoid(self.fc(x))

class MTUNetPlusPlus(nn.Module):
    """Backbone + attention gate + MLP classifier with learnable prototypes.

    Args:
        num_classes: Number of output classes (and of prototype vectors).
        feature_dim: Channel count of the backbone feature map; also the
            length of each prototype vector.
    """

    def __init__(self, num_classes, feature_dim=2048):
        super().__init__()
        self.feature_extractor = FeatureExtractor()
        self.attention = AttentionModule(feature_dim)

        # One learnable prototype per class, living in pooled-feature space.
        self.prototype_vectors = nn.Parameter(torch.randn(num_classes, feature_dim))

        # Two-layer MLP head on top of the pooled features.
        head_layers = [
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x, return_features=False):
        """Classify ``x``.

        Returns:
            The class logits, or — when ``return_features`` is True — the
            tuple ``(logits, pooled_features, prototype_distances,
            attended_features)``.
        """
        attended = self.attention(self.feature_extractor(x))

        # Global average pooling down to one vector per sample.
        pooled = torch.flatten(F.adaptive_avg_pool2d(attended, (1, 1)), start_dim=1)

        # Pairwise distance from every sample to every class prototype.
        distances = torch.cdist(pooled, self.prototype_vectors)

        logits = self.classifier(pooled)

        if not return_features:
            return logits
        return logits, pooled, distances, attended

class GradCAM:
    """Grad-CAM heat-map generator for one layer of a model.

    Forward and backward hooks on ``target_layer`` capture its activation and
    the gradient flowing into that activation; ``generate_cam`` combines the
    two into a normalised heat map.

    Note: a full backward hook is used when the installed PyTorch provides it.
    PyTorch does not allow the hooked layer's output to be modified in place
    by later layers in that case.

    Args:
        model: Model to explain; calling it must return ``(batch, classes)``
            scores.
        target_layer: Sub-module of ``model`` whose output is visualised.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.features = None

        # Register hooks (handles are kept so they can be removed later).
        self._handles = [target_layer.register_forward_hook(self._save_features)]
        if hasattr(target_layer, "register_full_backward_hook"):
            # register_backward_hook is deprecated for multi-op modules.
            self._handles.append(
                target_layer.register_full_backward_hook(self._save_gradients)
            )
        else:
            self._handles.append(
                target_layer.register_backward_hook(self._save_gradients)
            )

    def remove_hooks(self):
        """Detach the hooks this object registered on ``target_layer``."""
        for handle in self._handles:
            handle.remove()
        self._handles = []

    def _save_features(self, module, input, output):
        # Detached copy: the CAM arithmetic must not touch the autograd graph.
        self.features = output.detach()

    def _save_gradients(self, module, grad_input, grad_output):
        self.gradients = grad_output[0].detach()

    def generate_cam(self, input_image, target_class):
        """Return the Grad-CAM map for ``target_class``.

        The backward pass is seeded only at sample 0 of the batch, and the
        channel weights are the gradient averaged over batch and spatial
        dimensions.

        Args:
            input_image: ``(batch, channels, H, W)`` input tensor.
            target_class: Index of the class score to explain.

        Returns:
            A detached ``(batch, 1, H, W)`` tensor scaled to [0, 1]; all zeros
            when the map is constant (instead of NaN from a 0/0 division).

        Raises:
            RuntimeError: If the hooks captured nothing during the passes.
        """
        # Drop anything captured by a previous call.
        self.features = None
        self.gradients = None

        # Forward pass
        model_output = self.model(input_image)

        # Zero gradients
        self.model.zero_grad()

        # Backward pass for target class
        one_hot = torch.zeros_like(model_output)
        one_hot[0][target_class] = 1
        model_output.backward(gradient=one_hot)

        if self.features is None or self.gradients is None:
            raise RuntimeError(
                "GradCAM hooks captured no activations/gradients; "
                "is target_layer part of the model's forward pass?"
            )

        # Weight every channel by its pooled gradient (broadcast, no in-place
        # edits of the captured activation).
        channel_weights = self.gradients.mean(dim=(0, 2, 3))
        weighted = self.features * channel_weights.view(1, -1, 1, 1)

        cam = weighted.mean(dim=1, keepdim=True)
        cam = F.relu(cam)
        cam = F.interpolate(cam, size=input_image.shape[2:], mode='bilinear', align_corners=False)

        # Normalize to [0, 1]; a flat map stays at zero.
        cam = cam - cam.min()
        peak = cam.max()
        if peak > 0:
            cam = cam / peak

        return cam

def train_step(model, discriminator, batch, optimizer_G, optimizer_D):
    """Run one discriminator update followed by one generator/classifier update.

    Args:
        model: Network exposing a ``feature_extractor`` sub-module and
            returning class logits when called on images.
        discriminator: Network mapping a feature map to probabilities in
            [0, 1].
        batch: ``(images, labels)`` pair, already on the right device.
        optimizer_G: Optimizer for ``model``.
        optimizer_D: Optimizer for ``discriminator``.

    Returns:
        The combined generator loss (classification + 0.1 * adversarial) as a
        Python float.
    """
    real_images, labels = batch

    # Train Discriminator. The features are computed without a graph so this
    # step cannot push gradients into the feature extractor.
    optimizer_D.zero_grad()
    with torch.no_grad():
        features = model.feature_extractor(real_images)
    d_real = discriminator(features)
    # Targets must match the discriminator output shape, not the feature map.
    d_loss_real = F.binary_cross_entropy(d_real, torch.ones_like(d_real))
    d_loss_real.backward()
    optimizer_D.step()

    # Train Generator (Feature Extractor)
    optimizer_G.zero_grad()
    features = model.feature_extractor(real_images)
    d_fake = discriminator(features)
    g_loss = F.binary_cross_entropy(d_fake, torch.ones_like(d_fake))

    # Classification loss
    logits = model(real_images)
    cls_loss = F.cross_entropy(logits, labels)

    # Combined loss
    total_loss = cls_loss + 0.1 * g_loss
    total_loss.backward()
    optimizer_G.step()

    return total_loss.item()

class HAM10000Dataset(Dataset):
    """Image-folder dataset: one sub-directory of ``root_dir`` per class.

    Attributes are filled eagerly at construction time:
        classes: Sorted names of the class sub-directories.
        class_to_idx: Mapping from class name to contiguous integer label.
        images: Image file paths, ordered by class and then by file name.
        labels: Integer label for each entry of ``images``.
    """

    # Accepted image file extensions (matched case-insensitively).
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (string): Directory with all the images organized in class folders
            transform (callable, optional): Optional transform to be applied on a sample
        """
        self.root_dir = root_dir
        self.transform = transform
        # Only directories are classes; a stray file in root_dir (metadata,
        # README, ...) must not consume a label index.
        self.classes = sorted(
            entry for entry in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, entry))
        )
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}

        self.images = []
        self.labels = []

        # Load all image paths and labels; sorted so that the index -> sample
        # mapping does not depend on the OS directory listing order.
        for class_name in self.classes:
            class_path = os.path.join(root_dir, class_name)
            for img_name in sorted(os.listdir(class_path)):
                if img_name.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.images.append(os.path.join(class_path, img_name))
                    self.labels.append(self.class_to_idx[class_name])

    def __len__(self):
        """Return the number of images found."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for index ``idx``; image is RGB and transformed if a transform was given."""
        img_path = self.images[idx]
        # The context manager closes the underlying file right after decoding.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

def train_epoch(model, discriminator, train_loader, optimizer_G, optimizer_D, device):
    """Train ``model`` and ``discriminator`` for one pass over ``train_loader``.

    Each batch first updates the discriminator on detached backbone features
    (target 1), then updates the model on classification loss plus 0.1 times
    the adversarial loss.

    Args:
        model: Network exposing ``feature_extractor`` and returning logits.
        discriminator: Network mapping feature maps to ``(batch, 1)``
            probabilities.
        train_loader: Iterable of ``(images, labels)`` batches.
        optimizer_G: Optimizer for ``model``.
        optimizer_D: Optimizer for ``discriminator``.
        device: Device the batches are moved to.

    Returns:
        Mean combined generator loss over the batches, or 0.0 when the loader
        is empty.
    """
    model.train()
    discriminator.train()
    total_loss = 0

    num_batches = len(train_loader)
    if num_batches == 0:
        # Nothing to train on; avoid dividing by zero below.
        return 0.0

    for batch_idx, (images, labels) in enumerate(tqdm(train_loader)):
        images, labels = images.to(device), labels.to(device)
        batch_size = images.size(0)

        # Train Discriminator. Its input is detached anyway, so skip building
        # an autograd graph for the backbone in this step.
        optimizer_D.zero_grad()
        with torch.no_grad():
            features = model.feature_extractor(images)
        d_real = discriminator(features)
        # Targets are created directly on the target device.
        real_labels = torch.ones(batch_size, 1, device=device)
        d_loss_real = F.binary_cross_entropy(d_real, real_labels)
        d_loss_real.backward()
        optimizer_D.step()

        # Train Generator (Feature Extractor) and Classifier
        optimizer_G.zero_grad()
        features = model.feature_extractor(images)
        d_fake = discriminator(features)
        g_loss = F.binary_cross_entropy(d_fake, real_labels)

        # Classification loss
        logits = model(images)
        cls_loss = F.cross_entropy(logits, labels)

        # Combined loss
        total_g_loss = cls_loss + 0.1 * g_loss
        total_g_loss.backward()
        optimizer_G.step()

        total_loss += total_g_loss.item()

        if batch_idx % 100 == 0:
            print(f'Batch [{batch_idx}/{len(train_loader)}], '
                  f'Loss: {total_g_loss.item():.4f}, '
                  f'Class Loss: {cls_loss.item():.4f}, '
                  f'G Loss: {g_loss.item():.4f}')

    return total_loss / num_batches

def _load_episode_samples(dataset, indices, device):
    """Load each index once and return ``(images, labels)`` tensors on ``device``."""
    samples = [dataset[idx] for idx in indices]
    images = torch.stack([image for image, _ in samples]).to(device)
    labels = torch.tensor([int(label) for _, label in samples], device=device)
    return images, labels


def evaluate_episodes(model, dataset, num_episodes=2000, n_way=2, k_shot=1, device='cuda'):
    """
    Evaluate model on n-way k-shot tasks

    Every episode samples ``n_way`` classes, ``k_shot`` support examples and
    one query example per class, averages the support features into one
    prototype per class, and assigns each query to its nearest prototype
    (Euclidean distance).

    Args:
        model: trained model; ``model(x, return_features=True)[1]`` must be
            the per-sample feature vectors
        dataset: dataset to sample episodes from; needs a ``labels`` list
            aligned with its indices
        num_episodes: number of episodes to evaluate
        n_way: number of classes per episode (2 for 2-way tasks)
        k_shot: number of support examples per class
        device: device to run evaluation on

    Returns:
        Mean episode accuracy.

    Raises:
        ValueError: If ``n_way``/``k_shot`` are not positive, or fewer than
            ``n_way`` classes have at least ``k_shot + 1`` examples.
    """
    if n_way < 1 or k_shot < 1:
        raise ValueError("n_way and k_shot must both be at least 1")

    model.eval()
    accuracies = []

    # Group sample indices by class once instead of rescanning every episode.
    indices_by_class = {}
    for idx, label in enumerate(dataset.labels):
        indices_by_class.setdefault(label, []).append(idx)

    # A class needs k_shot support examples plus one distinct query example.
    eligible_classes = [
        cls for cls, members in indices_by_class.items() if len(members) > k_shot
    ]
    if len(eligible_classes) < n_way:
        raise ValueError(
            f"Need at least {n_way} classes with more than {k_shot} examples, "
            f"found {len(eligible_classes)}"
        )

    with torch.no_grad():
        for episode in range(num_episodes):
            # Randomly sample n classes
            episode_classes = random.sample(eligible_classes, n_way)

            support_indices = []
            query_indices = []
            for class_idx in episode_classes:
                # Draw k_shot + 1 distinct examples: the first k_shot form the
                # support set, the last one is the query.
                drawn = random.sample(indices_by_class[class_idx], k_shot + 1)
                support_indices.extend(drawn[:k_shot])
                query_indices.extend(drawn[k_shot:])

            # Prepare support and query sets
            support_images, support_labels = _load_episode_samples(dataset, support_indices, device)
            query_images, query_labels = _load_episode_samples(dataset, query_indices, device)

            # Get model predictions
            support_features = model(support_images, return_features=True)[1]  # Get pooled features
            query_features = model(query_images, return_features=True)[1]

            # One prototype (mean support feature) per episode class
            prototypes = torch.stack([
                support_features[support_labels == cls].mean(0)
                for cls in episode_classes
            ])

            # Distance from every query to every prototype; nearest one wins
            distances = (query_features.unsqueeze(1) - prototypes.unsqueeze(0)).norm(dim=2)
            class_ids = torch.tensor([int(cls) for cls in episode_classes], device=device)
            predicted = class_ids[distances.argmin(dim=1)]

            # Calculate accuracy for this episode
            accuracy = (predicted == query_labels).sum().item() / len(query_indices)
            accuracies.append(accuracy)

            if (episode + 1) % 100 == 0:
                print(f'Episode {episode + 1}/{num_episodes}, Running Avg Accuracy: {np.mean(accuracies):.4f}')

    final_accuracy = np.mean(accuracies)
    print(f'\nFinal Average Accuracy over {num_episodes} episodes: {final_accuracy:.4f}')
    return final_accuracy

def main():
    """Train the model on the image folder and report few-shot accuracy.

    Per epoch: one call to ``train_epoch``, then 2000 two-way one-shot
    episodes via ``evaluate_episodes`` on the same dataset. A checkpoint is
    written every fifth epoch.
    """
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Preprocessing: resize, tensor conversion, channel normalisation
    preprocessing_steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    transform = transforms.Compose(preprocessing_steps)

    # Dataset and loader
    dataset = HAM10000Dataset(
        root_dir='/kaggle/input/ham10000-and-gan/synthetic_images',
        transform=transform,
    )
    train_loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

    # Networks
    model = MTUNetPlusPlus(num_classes=7).to(device)
    discriminator = Discriminator(input_channels=2048).to(device)

    # One Adam optimizer per network
    optimizer_G = torch.optim.Adam(model.parameters(), lr=0.0001)
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0001)

    num_epochs = 50
    for epoch in range(num_epochs):
        print(f'\nEpoch {epoch+1}/{num_epochs}')
        avg_loss = train_epoch(
            model, discriminator, train_loader, optimizer_G, optimizer_D, device
        )
        print(f'Average loss: {avg_loss:.4f}')

        # Episodic evaluation after every epoch
        print("\nEvaluating 2-way tasks...")
        avg_accuracy = evaluate_episodes(
            model, dataset, num_episodes=2000, n_way=2, k_shot=1, device=device
        )
        print(f'Average accuracy over 2000 episodes: {avg_accuracy:.4f}')

        # Checkpoints are only written on every fifth epoch
        if (epoch + 1) % 5 != 0:
            continue
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'optimizer_G_state_dict': optimizer_G.state_dict(),
            'optimizer_D_state_dict': optimizer_D.state_dict(),
            'episode_accuracy': avg_accuracy,
        }
        torch.save(checkpoint, f'checkpoint_epoch_{epoch+1}.pth')


if __name__ == "__main__":
    main()