In [1]:
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
import numpy as np
import torch
import my_utils
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
import torchvision.models as  models
from torch.optim import Adam
#from torchvision.models import vit_b_16, vit_t_16, vit_s_16  # For vision transformers
from torch.optim import AdamW
from sklearn.metrics import f1_score
import matplotlib.pyplot as plt
#from vision_transformer_cp 
from vision_transformer_cp import DINOHead, vit_small, vit_tiny, vit_base, vit_tinyer, vit_tiniest



In [2]:
  # Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [3]:
# Hyperparameters
batch_size = 64
image_size = 32  # CIFAR-10 image size
n_epochs = 10
learning_rate = 1e-4
bag_size = 5 # Number of instances in each bag


In [4]:
class DataAugmentationDINO(object):
    def __init__(self, global_crops_scale, local_crops_scale, local_crops_number, image_size=224):
        flip_and_color_jitter = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply(
                [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
                p=0.8
            ),
            transforms.RandomGrayscale(p=0.2),
        ])
        self.image_size = image_size

        self.normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],  # CIFAR-10 mean
                                 std=[0.2470, 0.2435, 0.2616])   # CIFAR-10 std
        ])

        # First global crop
        self.global_transfo1 = transforms.Compose([
            transforms.Resize(self.image_size),  # Ensure images are resized
            transforms.RandomResizedCrop(self.image_size, scale=global_crops_scale),
            flip_and_color_jitter,
            self.normalize,
        ])
        # Second global crop
        self.global_transfo2 = transforms.Compose([
            transforms.Resize(self.image_size),
            transforms.RandomResizedCrop(self.image_size, scale=global_crops_scale),
            flip_and_color_jitter,
            self.normalize,
        ])
        # Transformation for the local small crops
        self.local_crops_number = local_crops_number
        self.local_transfo = transforms.Compose([
            transforms.Resize(self.image_size),
            transforms.RandomResizedCrop(self.image_size, scale=local_crops_scale),
            flip_and_color_jitter,
            self.normalize,
        ])

    def __call__(self, image):
        crops = []
        # Generate global crops
        crops.append(self.global_transfo1(image))
        crops.append(self.global_transfo2(image))
        # Generate local crops
        for _ in range(self.local_crops_number):
            crops.append(self.local_transfo(image))
        # Transform the original image to tensor
        unaugmented_image = transforms.Resize((self.image_size, self.image_size))(image)
        unaugmented_image = self.normalize(unaugmented_image)
       # print(f"Unaugmented image shape: {unaugmented_image.shape}")
        #for idx, crop in enumerate(crops):
           # print(f"Crop {idx} shape: {crop.shape}")
        return unaugmented_image, crops


In [5]:
# Parameters for data augmentation
global_crops_scale = (0.4, 1.0)
local_crops_scale = (0.05, 0.4)
local_crops_number = 4  # Number of local crops
image_size = 224  # Desired image size after resizing

In [6]:
#image_size=224
data_transform = DataAugmentationDINO(
    global_crops_scale=global_crops_scale,
    local_crops_scale=local_crops_scale,
    local_crops_number=local_crops_number,
    image_size=image_size
)

In [7]:
# Load datasets
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=data_transform)
val_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=data_transform)


Files already downloaded and verified
Files already downloaded and verified


In [8]:
class BagDataset(Dataset):
    def __init__(self, dataset, bag_size=5):
        self.dataset = dataset
        self.bag_size = bag_size
        self.indices = list(range(len(self.dataset)))
        self.labels = [self.dataset.targets[i] for i in self.indices]
        self.num_classes = len(set(self.labels))
        print(f'Initializing BagDataset with {len(self.indices)} samples, bag size: {self.bag_size}')
    
    def __len__(self):
        return len(self.dataset) // self.bag_size
    
    def __getitem__(self, idx):
        # Get bag indices
        start_idx = idx * self.bag_size
        end_idx = start_idx + self.bag_size
        bag_indices = self.indices[start_idx:end_idx]
    
        bag = []
        labels = []
        for i in bag_indices:
            data, label = self.dataset[i]
            unaugmented_image, crops = data
            # Include both unaugmented image and crops
            bag.append(unaugmented_image)
            bag.extend(crops)
            labels.append(label)
    
        # Convert bag and labels to tensors
        bag = torch.stack(bag, dim=0)
        labels = torch.tensor(labels)
    
        # For MIL, assign the bag label (e.g., majority label or first label)
        bag_label = labels[0]
    
        return bag, bag_label


In [9]:
# Create BagDatasets
train_bag_dataset = BagDataset(train_dataset, bag_size=bag_size)
val_bag_dataset = BagDataset(val_dataset, bag_size=bag_size)

# Create DataLoaders
train_loader = DataLoader(train_bag_dataset, batch_size=batch_size, shuffle=True, drop_last=False, pin_memory=True)
val_loader = DataLoader(val_bag_dataset, batch_size=batch_size, shuffle=False, drop_last=False, pin_memory=True)


Initializing BagDataset with 50000 samples, bag size: 5
Initializing BagDataset with 10000 samples, bag size: 5


In [10]:
class DINO(nn.Module):
    def __init__(self, out_dim=500, use_bn=False, model_type="tiny"):
        super().__init__()
        model_map = {'tiny':vit_tiny(), 'small':vit_small(), 'base':vit_base(),
                     'vit_tinyer':vit_tinyer()
                    }
        # Student network
        self.student = model_map[model_type]
        embed_dim = self.student.embed_dim
        
        self.student = nn.Sequential(
            self.student,
            DINOHead(embed_dim, out_dim, use_bn)
        )
        # Teacher network
        self.teacher = model_map[model_type]
        self.teacher = nn.Sequential(
            self.teacher,
            DINOHead(embed_dim, out_dim, use_bn)
        )
        # Initialize teacher and student with same weights
        self.teacher.load_state_dict(self.student.state_dict())
        # Turn off gradients for teacher network
        for param in self.teacher.parameters():
            param.requires_grad = False

    def forward(self, x, is_teacher=False):
        # Forward pass through Perceiver
        batch_size = x.shape[0]
        x = self.student[0](x) if not is_teacher else self.teacher[0](x)
        # combine latents
        x = x.view(batch_size, 1, -1)
        x = self.student[1](x) if not is_teacher else self.teacher[1](x)
        return x

    def get_last_selfattention(self, x):
        return self.student[0].get_last_selfattention(x)

In [11]:
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']


In [12]:
class Classifier(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(Classifier, self).__init__()
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, num_classes),
            #nn.Softmax(dim=1)
        )

    def forward(self, x):
        return self.classifier(x)


In [13]:
class InstanceClassifier(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(InstanceClassifier, self).__init__()
        self.classifier = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        # x shape: (batch_size, num_instances, feature_dim)
        batch_size, num_instances, feature_dim = x.size()
        x = x.view(-1, feature_dim)
        predictions = self.classifier(x)
        predictions = predictions.view(batch_size, num_instances, -1)
        return predictions


In [14]:
# Gated Attention Mechanism
class GatedAttention(nn.Module):
    def __init__(self, input_dim):
        super(GatedAttention, self).__init__()
        self.M = input_dim
        self.L = 128
        self.ATTENTION_BRANCHES = 1

        self.feature_extractor_part2 = nn.Sequential(
            nn.Linear(input_dim, self.M),
            nn.ReLU(),
        )

        self.attention_V = nn.Sequential(
            nn.Linear(self.M, self.L),
            nn.Tanh()
        )

        self.attention_U = nn.Sequential(
            nn.Linear(self.M, self.L),
            nn.Sigmoid()
        )

        self.attention_w = nn.Linear(self.L, self.ATTENTION_BRANCHES)

    def forward(self, x):
        # x shape: (batch_size, num_instances, feature_dim)
        batch_size, num_instances, feature_dim = x.size()
        x = x.view(batch_size * num_instances, feature_dim)

        # Apply feature_extractor_part2
        H = self.feature_extractor_part2(x)

        # Compute attention weights
        A_V = self.attention_V(H)
        A_U = self.attention_U(H)
        A = self.attention_w(A_V * A_U)

        # Reshape for softmax
        A = A.view(batch_size, num_instances, self.ATTENTION_BRANCHES)
        A = A.transpose(1, 2)
        A = F.softmax(A, dim=2)

        # Reshape H
        H = H.view(batch_size, num_instances, -1)

        # Compute bag representation
        Z = torch.bmm(A, H)

        if self.ATTENTION_BRANCHES == 1:
            Z = Z.squeeze(1)
        return Z, A


In [15]:
class MILModel(nn.Module):
    def __init__(self, feature_extractor, aggregator, classifier):
        super(MILModel, self).__init__()
        self.feature_extractor = feature_extractor
        self.aggregator = aggregator
        self.classifier = classifier  # InstanceClassifier

    def forward(self, bag):
        batch_size, num_instances, channels, height, width = bag.shape
        bag = bag.view(-1, channels, height, width)
        features = self.feature_extractor(bag)
        features = features.view(batch_size, num_instances, -1)

        # Get instance-level predictions
        instance_predictions = self.classifier(features)  # Shape: (batch_size, num_instances, num_classes)

        # Get attention weights
        bag_representation, attention_weights = self.aggregator(features)  # attention_weights shape: (batch_size, num_attention_branches, num_instances)

        # Use attention weights to compute weighted sum of instance predictions
        attention_weights = attention_weights.squeeze(1)  # Shape: (batch_size, num_instances)
        attention_weights = attention_weights.unsqueeze(2)  # Shape: (batch_size, num_instances, 1)
        weighted_predictions = (instance_predictions * attention_weights).sum(dim=1)  # Shape: (batch_size, num_classes)

        return weighted_predictions, attention_weights,instance_predictions


In [16]:
import timm
# Number of classes in CIFAR-10
num_classes = 10

#model_name='vit_small_patch16_224_dino'

# Initialize model components
feature_extractor = DINO(out_dim=768, use_bn=False, model_type='base')
#model=models.resnet50(pretrained=True)

# Remove the classification head to get feature vectors
#feature_extractor.reset_classifier(0)
#feature_extractor = nn.Sequential(*list(model.children())[:-1]).to(device)
aggregator = GatedAttention(input_dim=768).to(device)
#classifier = Classifier(input_dim=2048, num_classes=num_classes).to(device)
classifier = InstanceClassifier(input_dim=768, num_classes=num_classes).to(device)
model = MILModel(feature_extractor, aggregator, classifier).to(device)


  from .autonotebook import tqdm as notebook_tqdm
  WeightNorm.apply(module, name, dim)
  return t.to(


In [17]:
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)


In [18]:
from torch.amp import autocast

scaler = torch.amp.GradScaler(device='cuda')

def train(model, train_loader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    correct_train = 0
    total_train = 0

    for batch_idx, (bags, bag_labels) in enumerate(train_loader):
        bags = bags.to(device)
        bag_labels = bag_labels.to(device)
        print(f'Processing Batch {batch_idx + 1}/{len(train_loader)}')
        optimizer.zero_grad()
        with autocast(enabled=True,device_type='cuda'):
            outputs, attention_weights,_ = model(bags)
            loss = criterion(outputs, bag_labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        torch.cuda.empty_cache()  
        
        total_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        correct_train += (predicted == bag_labels).sum().item()
        total_train += bag_labels.size(0)
        if (batch_idx + 1) % 100 == 0:
            print(f'Batch {batch_idx + 1}/{len(train_loader)}: Loss = {loss.item():.4f}')
    train_accuracy = 100 * correct_train / total_train
    average_loss = total_loss / len(train_loader)
    return train_accuracy, average_loss


In [19]:
from sklearn.metrics import f1_score

def validate(model, val_loader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    all_predictions = []
    all_targets = []
    images_list = []
    attention_weights_list = []

    with torch.no_grad():
        for bags, bag_labels in val_loader:
            bags = bags.to(device)
            bag_labels = bag_labels.to(device)

            with autocast(enabled=True,device_type='cuda'):
                outputs, attention_weights, _ = model(bags)  # Unpack three values
                loss = criterion(outputs, bag_labels)

            total_loss += loss.item()

            _, predicted = torch.max(outputs, 1)
            correct += (predicted == bag_labels).sum().item()
            total += bag_labels.size(0)

            all_predictions.extend(predicted.cpu().numpy())
            all_targets.extend(bag_labels.cpu().numpy())
            images_list.extend(bags.cpu())
            attention_weights_list.extend(attention_weights.cpu())

    val_accuracy = 100 * correct / total
    average_loss = total_loss / len(val_loader)
    f1 = f1_score(all_targets, all_predictions, average='macro')

    print(f'Validation Loss: {average_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%, F1 Score: {f1:.4f}')
    return val_accuracy, images_list, all_predictions, all_targets, attention_weights_list


In [20]:
def visualize_attention(images_list, attention_weights_list, class_names, num_images=5, epoch=None):
    import matplotlib.pyplot as plt
    import numpy as np

    # Limit the number of images to display
    num_images = min(num_images, len(images_list))

    for idx in range(num_images):
        bag = images_list[idx]  # Shape: (num_instances, 3, 224, 224)
        attention_weights = attention_weights_list[idx]  # Shape: (num_instances, 1)

        # Ensure attention_weights is a 1D array
        attention_weights = attention_weights.squeeze().cpu().numpy()

        # Sort instances by attention weight in descending order
        sorted_indices = np.argsort(-attention_weights)
        top_instances = bag[sorted_indices][:5]  # Top 5 instances
        top_weights = attention_weights[sorted_indices][:5]

        plt.figure(figsize=(15, 3))
        for i in range(len(top_instances)):
            plt.subplot(1, 5, i+1)
            image = top_instances[i]
            image = image.permute(1, 2, 0).numpy()  # Convert from (C, H, W) to (H, W, C)
            # Un-normalize the image
            image = image * np.array([0.2470, 0.2435, 0.2616]) + np.array([0.4914, 0.4822, 0.4465])
            image = np.clip(image, 0, 1)
            plt.imshow(image)
            plt.title(f'Weight: {top_weights[i]:.2f}')
            plt.axis('off')
        plt.tight_layout()
        if epoch is not None:
            plt.savefig(f'attention_epoch_{epoch + 1}_bag_{idx + 1}.png')  # Save plot
        plt.show()
        plt.close()  # Close the figure to free memory


In [21]:
# Visualization function
def visualize_predictions(images_list, predictions, targets, class_names, num_images=10):
    import matplotlib.pyplot as plt
    import numpy as np

    # Limit the number of images to display
    num_images = min(num_images, len(images_list))

    plt.figure(figsize=(15, 6))
    for idx in range(num_images):
        plt.subplot(2, 5, idx+1)
        # images_list contains bags; extract one image per bag for visualization
        image = images_list[idx][0]  # Take the first image from the bag
        image = image.permute(1, 2, 0).numpy()  # Convert tensor to numpy array
        # Un-normalize the image
        image = image * np.array([0.2470, 0.2435, 0.2616]) + np.array([0.4914, 0.4822, 0.4465])
        image = np.clip(image, 0, 1)
        plt.imshow(image)
        pred_label = class_names[predictions[idx]]
        true_label = class_names[targets[idx]]
        plt.title(f"Predicted: {pred_label}\nActual: {true_label}")
        plt.axis('off')
    plt.tight_layout()
    plt.show()


In [22]:
#class_names = train_dataset.classes  # ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']


In [23]:
# Training loop
for epoch in range(n_epochs):
    print(f'Epoch [{epoch + 1}/{n_epochs}]')
    train_accuracy, train_loss = train(model, train_loader, criterion, optimizer, device)
    print(f'Training Loss: {train_loss:.4f}, Training Accuracy: {train_accuracy:.2f}%')

    val_accuracy, val_images_list, val_predictions, val_targets, attention_weights_list = validate(model, val_loader, criterion, device)    
    # Visualize predictions after validation
    #visualize_attention(val_images_list, attention_weights_list, class_names, num_images=5)
    visualize_predictions(val_images_list, val_predictions, val_targets, class_names, num_images=10)


Epoch [1/10]
Processing Batch 1/157
Processing Batch 2/157
Processing Batch 3/157
Processing Batch 4/157
Processing Batch 5/157
Processing Batch 6/157
Processing Batch 7/157
Processing Batch 8/157
Processing Batch 9/157
Processing Batch 10/157
Processing Batch 11/157
Processing Batch 12/157
Processing Batch 13/157
Processing Batch 14/157
Processing Batch 15/157
Processing Batch 16/157
Processing Batch 17/157
Processing Batch 18/157
Processing Batch 19/157
Processing Batch 20/157
Processing Batch 21/157
Processing Batch 22/157
Processing Batch 23/157
Processing Batch 24/157
Processing Batch 25/157
Processing Batch 26/157
Processing Batch 27/157
Processing Batch 28/157
Processing Batch 29/157
Processing Batch 30/157
Processing Batch 31/157
Processing Batch 32/157
Processing Batch 33/157
Processing Batch 34/157
Processing Batch 35/157
Processing Batch 36/157
Processing Batch 37/157
Processing Batch 38/157
Processing Batch 39/157
Processing Batch 40/157
Processing Batch 41/157
Processing B

RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
