In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import torchvision.models as models
import torch.nn.functional as F
import numpy as np
from torch.optim.lr_scheduler import StepLR
import copy
from sklearn.neighbors import KNeighborsClassifier

# Output directory for every checkpoint this notebook writes; created if missing.
feature_dir = '/kaggle/working/files'
os.makedirs(name=feature_dir, exist_ok=True)




# Distillation Loss Function
def distillation_loss(student_outputs, teacher_outputs, temperature=2.0):
    teacher_probs = F.softmax(teacher_outputs / temperature, dim=1)
    student_probs = F.log_softmax(student_outputs / temperature, dim=1)
    return F.kl_div(student_probs, teacher_probs, reduction='batchmean') * (temperature ** 2)

# Custom Dataset Class
class CustomImageDataset(Dataset):
    """In-memory dataset over an indexable collection of images.

    Parameters:
    - images: any object supporting ``len()`` and integer indexing.
    - targets: indexable labels aligned with ``images``, or None for unlabeled data,
      in which case every item's target is -1.
    - transform: optional callable applied to each image (skipped when falsy).
    """

    def __init__(self, images, targets=None, transform=None):
        self.images = images
        self.targets = targets
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        sample = self.images[idx]
        if self.transform:
            sample = self.transform(sample)
        if self.targets is None:
            return sample, -1
        return sample, self.targets[idx]

    
# Load Data from .pth file
def load_data_from_pth(pth_path, has_targets=True):
    """Load a dictionary saved with ``torch.save`` and return its data and targets.

    Parameters:
    - pth_path: path to the .pth file; it must contain a 'data' entry and, when
      ``has_targets`` is True, a 'targets' entry.
    - has_targets: when False the 'targets' entry is not read and None is returned
      in its place (used for the unlabeled training splits).

    Returns:
    - (data, targets) tuple; targets is None when ``has_targets`` is False.

    Raises:
    - KeyError if a required entry is missing from the file.
    """
    # weights_only=False: the dataset files may hold non-tensor pickled objects
    # (presumably numpy arrays, since items are fed to ToPILImage), which the
    # weights_only=True default of newer PyTorch refuses to load. This runs pickle,
    # so only load trusted dataset files. map_location='cpu' keeps any tensors on
    # the CPU regardless of the device they were saved from.
    data_dict = torch.load(pth_path, map_location='cpu', weights_only=False)
    data = data_dict['data']
    targets = data_dict['targets'] if has_targets else None
    return data, targets

# Paths to Datasets
# Edit DATASET_ROOT to run on your machine.
DATASET_ROOT = '/kaggle/input/dataset-for-proj-2/dataset/dataset'


def _split_path(part, split, index):
    """Return the path of file ``index`` of the ``split`` ('train' or 'eval') data of ``part``.

    ``part`` is 'part_one' or 'part_two'.
    """
    return f'{DATASET_ROOT}/{part}_dataset/{split}_data/{index}_{split}_data.tar.pth'


# D1..D10 are files 1..10 of part one; D11..D20 are files 1..10 of part two.
d1_path = _split_path('part_one', 'train', 1)
d1hat_path = _split_path('part_one', 'eval', 1)
d2_path = _split_path('part_one', 'train', 2)
d2hat_path = _split_path('part_one', 'eval', 2)
d3_path = _split_path('part_one', 'train', 3)
d3hat_path = _split_path('part_one', 'eval', 3)
d4_path = _split_path('part_one', 'train', 4)
d4hat_path = _split_path('part_one', 'eval', 4)
d5_path = _split_path('part_one', 'train', 5)
d5hat_path = _split_path('part_one', 'eval', 5)
d6_path = _split_path('part_one', 'train', 6)
d6hat_path = _split_path('part_one', 'eval', 6)
d7_path = _split_path('part_one', 'train', 7)
d7hat_path = _split_path('part_one', 'eval', 7)
d8_path = _split_path('part_one', 'train', 8)
d8hat_path = _split_path('part_one', 'eval', 8)
d9_path = _split_path('part_one', 'train', 9)
d9hat_path = _split_path('part_one', 'eval', 9)
d10_path = _split_path('part_one', 'train', 10)
d10hat_path = _split_path('part_one', 'eval', 10)

d11_path = _split_path('part_two', 'train', 1)
d11hat_path = _split_path('part_two', 'eval', 1)
d12_path = _split_path('part_two', 'train', 2)
d12hat_path = _split_path('part_two', 'eval', 2)
d13_path = _split_path('part_two', 'train', 3)
d13hat_path = _split_path('part_two', 'eval', 3)
d14_path = _split_path('part_two', 'train', 4)
d14hat_path = _split_path('part_two', 'eval', 4)
d15_path = _split_path('part_two', 'train', 5)
d15hat_path = _split_path('part_two', 'eval', 5)
d16_path = _split_path('part_two', 'train', 6)
d16hat_path = _split_path('part_two', 'eval', 6)
d17_path = _split_path('part_two', 'train', 7)
d17hat_path = _split_path('part_two', 'eval', 7)
d18_path = _split_path('part_two', 'train', 8)
d18hat_path = _split_path('part_two', 'eval', 8)
d19_path = _split_path('part_two', 'train', 9)
d19hat_path = _split_path('part_two', 'eval', 9)
d20_path = _split_path('part_two', 'train', 10)
d20hat_path = _split_path('part_two', 'eval', 10)

print({"reached_path_to_datasets"})

# Load Datasets
# Only the D1 training split keeps its labels; the D2..D20 training splits are loaded
# as unlabeled data (targets skipped). Every evaluation split DkHat keeps its labels.
d1_data, d1_targets = load_data_from_pth(d1_path)
d1hat_data, d1hat_targets = load_data_from_pth(d1hat_path)

d2_data = load_data_from_pth(d2_path, has_targets=False)[0]
d2hat_data, d2hat_targets = load_data_from_pth(d2hat_path)

d3_data = load_data_from_pth(d3_path, has_targets=False)[0]
d3hat_data, d3hat_targets = load_data_from_pth(d3hat_path)

d4_data = load_data_from_pth(d4_path, has_targets=False)[0]
d4hat_data, d4hat_targets = load_data_from_pth(d4hat_path)

d5_data = load_data_from_pth(d5_path, has_targets=False)[0]
d5hat_data, d5hat_targets = load_data_from_pth(d5hat_path)

d6_data = load_data_from_pth(d6_path, has_targets=False)[0]
d6hat_data, d6hat_targets = load_data_from_pth(d6hat_path)

d7_data = load_data_from_pth(d7_path, has_targets=False)[0]
d7hat_data, d7hat_targets = load_data_from_pth(d7hat_path)

d8_data = load_data_from_pth(d8_path, has_targets=False)[0]
d8hat_data, d8hat_targets = load_data_from_pth(d8hat_path)

d9_data = load_data_from_pth(d9_path, has_targets=False)[0]
d9hat_data, d9hat_targets = load_data_from_pth(d9hat_path)

d10_data = load_data_from_pth(d10_path, has_targets=False)[0]
d10hat_data, d10hat_targets = load_data_from_pth(d10hat_path)

d11_data = load_data_from_pth(d11_path, has_targets=False)[0]
d11hat_data, d11hat_targets = load_data_from_pth(d11hat_path)

d12_data = load_data_from_pth(d12_path, has_targets=False)[0]
d12hat_data, d12hat_targets = load_data_from_pth(d12hat_path)

d13_data = load_data_from_pth(d13_path, has_targets=False)[0]
d13hat_data, d13hat_targets = load_data_from_pth(d13hat_path)

d14_data = load_data_from_pth(d14_path, has_targets=False)[0]
d14hat_data, d14hat_targets = load_data_from_pth(d14hat_path)

d15_data = load_data_from_pth(d15_path, has_targets=False)[0]
d15hat_data, d15hat_targets = load_data_from_pth(d15hat_path)

d16_data = load_data_from_pth(d16_path, has_targets=False)[0]
d16hat_data, d16hat_targets = load_data_from_pth(d16hat_path)

d17_data = load_data_from_pth(d17_path, has_targets=False)[0]
d17hat_data, d17hat_targets = load_data_from_pth(d17hat_path)

d18_data = load_data_from_pth(d18_path, has_targets=False)[0]
d18hat_data, d18hat_targets = load_data_from_pth(d18hat_path)

d19_data = load_data_from_pth(d19_path, has_targets=False)[0]
d19hat_data, d19hat_targets = load_data_from_pth(d19hat_path)

d20_data = load_data_from_pth(d20_path, has_targets=False)[0]
d20hat_data, d20hat_targets = load_data_from_pth(d20hat_path)


# Data Transformations
# ImageNet channel statistics used for normalisation (the backbone is ImageNet-pretrained).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def _normalize_steps():
    """Return fresh ToTensor + ImageNet-Normalize steps ending both pipelines."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]


# Training pipeline: resize to 224x224, then random flip / rotation / colour jitter /
# resized crop as augmentation, then tensor conversion and normalisation.
train_transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(20),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.2),
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
        *_normalize_steps(),
    ]
)

# Evaluation pipeline: deterministic resize + normalisation only.
test_transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        *_normalize_steps(),
    ]
)

# DataLoaders
BATCH_SIZE = 32

# Labeled D1 training data: augmented and shuffled.
train_dataset = CustomImageDataset(d1_data, d1_targets, transform=train_transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Labeled evaluation splits D1hat..D20hat: deterministic preprocessing, fixed order.
d1hat_dataset = CustomImageDataset(d1hat_data, d1hat_targets, transform=test_transform)
d1hat_loader = DataLoader(d1hat_dataset, batch_size=BATCH_SIZE, shuffle=False)

d2hat_dataset = CustomImageDataset(d2hat_data, d2hat_targets, transform=test_transform)
d2hat_loader = DataLoader(d2hat_dataset, batch_size=BATCH_SIZE, shuffle=False)

d3hat_dataset = CustomImageDataset(d3hat_data, d3hat_targets, transform=test_transform)
d3hat_loader = DataLoader(d3hat_dataset, batch_size=BATCH_SIZE, shuffle=False)

d4hat_dataset = CustomImageDataset(d4hat_data, d4hat_targets, transform=test_transform)
d4hat_loader = DataLoader(d4hat_dataset, batch_size=BATCH_SIZE, shuffle=False)

d5hat_dataset = CustomImageDataset(d5hat_data, d5hat_targets, transform=test_transform)
d5hat_loader = DataLoader(d5hat_dataset, batch_size=BATCH_SIZE, shuffle=False)

d6hat_dataset = CustomImageDataset(d6hat_data, d6hat_targets, transform=test_transform)
d6hat_loader = DataLoader(d6hat_dataset, batch_size=BATCH_SIZE, shuffle=False)

d7hat_dataset = CustomImageDataset(d7hat_data, d7hat_targets, transform=test_transform)
d7hat_loader = DataLoader(d7hat_dataset, batch_size=BATCH_SIZE, shuffle=False)

d8hat_dataset = CustomImageDataset(d8hat_data, d8hat_targets, transform=test_transform)
d8hat_loader = DataLoader(d8hat_dataset, batch_size=BATCH_SIZE, shuffle=False)

d9hat_dataset = CustomImageDataset(d9hat_data, d9hat_targets, transform=test_transform)
d9hat_loader = DataLoader(d9hat_dataset, batch_size=BATCH_SIZE, shuffle=False)

d10hat_dataset = CustomImageDataset(d10hat_data, d10hat_targets, transform=test_transform)
d10hat_loader = DataLoader(d10hat_dataset, batch_size=BATCH_SIZE, shuffle=False)

d11hat_dataset = CustomImageDataset(d11hat_data, d11hat_targets, transform=test_transform)
d11hat_loader = DataLoader(d11hat_dataset, batch_size=BATCH_SIZE, shuffle=False)

d12hat_dataset = CustomImageDataset(d12hat_data, d12hat_targets, transform=test_transform)
d12hat_loader = DataLoader(d12hat_dataset, batch_size=BATCH_SIZE, shuffle=False)

d13hat_dataset = CustomImageDataset(d13hat_data, d13hat_targets, transform=test_transform)
d13hat_loader = DataLoader(d13hat_dataset, batch_size=BATCH_SIZE, shuffle=False)

d14hat_dataset = CustomImageDataset(d14hat_data, d14hat_targets, transform=test_transform)
d14hat_loader = DataLoader(d14hat_dataset, batch_size=BATCH_SIZE, shuffle=False)

d15hat_dataset = CustomImageDataset(d15hat_data, d15hat_targets, transform=test_transform)
d15hat_loader = DataLoader(d15hat_dataset, batch_size=BATCH_SIZE, shuffle=False)

d16hat_dataset = CustomImageDataset(d16hat_data, d16hat_targets, transform=test_transform)
d16hat_loader = DataLoader(d16hat_dataset, batch_size=BATCH_SIZE, shuffle=False)

d17hat_dataset = CustomImageDataset(d17hat_data, d17hat_targets, transform=test_transform)
d17hat_loader = DataLoader(d17hat_dataset, batch_size=BATCH_SIZE, shuffle=False)

d18hat_dataset = CustomImageDataset(d18hat_data, d18hat_targets, transform=test_transform)
d18hat_loader = DataLoader(d18hat_dataset, batch_size=BATCH_SIZE, shuffle=False)

d19hat_dataset = CustomImageDataset(d19hat_data, d19hat_targets, transform=test_transform)
d19hat_loader = DataLoader(d19hat_dataset, batch_size=BATCH_SIZE, shuffle=False)

d20hat_dataset = CustomImageDataset(d20hat_data, d20hat_targets, transform=test_transform)
d20hat_loader = DataLoader(d20hat_dataset, batch_size=BATCH_SIZE, shuffle=False)


# Feature Extractor: ResNet50
class ResNet50Extractor(nn.Module):
    """ImageNet-pretrained ResNet-50 backbone with its final child (the `fc` head) removed.

    forward() returns the output of the remaining layers, which ends with ResNet's
    global average pooling; callers flatten it to (batch, -1).
    """

    def __init__(self):
        super().__init__()
        weights_enum = getattr(models, 'ResNet50_Weights', None)
        if weights_enum is not None:
            # torchvision >= 0.13 deprecates `pretrained=True`; IMAGENET1K_V1 is the
            # checkpoint that flag selected, so the loaded weights are unchanged.
            resnet_model = models.resnet50(weights=weights_enum.IMAGENET1K_V1)
        else:
            # Older torchvision without the multi-weight API.
            resnet_model = models.resnet50(pretrained=True)
        # Keep every child module except the last one (the classification layer).
        self.features = nn.Sequential(*list(resnet_model.children())[:-1])

    def forward(self, x):
        return self.features(x)
    
# Learning with Prototypes Model
class LearningWithPrototypes(nn.Module):
    """Linear classifier on backbone features plus one learnable prototype per class.

    Parameters:
    - feature_extractor: module mapping an input batch to a (batch, ...) tensor; its
      output is flattened to (batch, prototype_dim) before the linear head.
    - num_classes: number of classes (outputs of the head and number of prototypes).
    - prototype_dim: flattened feature size (default 2048).
    """

    def __init__(self, feature_extractor, num_classes, prototype_dim=2048):
        super().__init__()
        self.feature_extractor = feature_extractor
        self.fc = nn.Linear(prototype_dim, num_classes)
        self.num_classes = num_classes
        # One randomly initialised prototype row per class; registered as a Parameter,
        # so it is also returned by model.parameters().
        self.prototypes = nn.Parameter(torch.randn(num_classes, prototype_dim))

    def forward(self, x):
        """Return ``(logits, flat_features)`` for the input batch ``x``."""
        raw = self.feature_extractor(x)
        flat = raw.view(raw.size(0), -1)
        return self.fc(flat), flat

    def prototype_loss(self, features, labels):
        """Cross-entropy where each class score is the negative Euclidean distance to its prototype."""
        neg_dist = -torch.cdist(features, self.prototypes)
        return F.cross_entropy(neg_dist, labels)

    def update_prototypes(self, features, labels):
        """Overwrite each class prototype with the mean of that class's rows in ``features``.

        Classes with no rows in ``labels`` keep their current prototype.
        """
        for cls in range(self.num_classes):
            members = features[labels == cls]
            if members.size(0) == 0:
                continue
            self.prototypes.data[cls] = members.mean(dim=0)

    def update_prototypes_with_ratp(self, features, labels, alpha=0.5):
        """
        Blend each class prototype toward the mean of the supplied features of that class.

        No confidence filtering happens here; callers pass whichever samples they
        consider reliable.

        Parameters:
        - features: Tensor of shape (batch_size, feature_dim).
        - labels: Tensor of shape (batch_size,), (pseudo-)labels for ``features``.
        - alpha: Weight kept on the old prototype; the new class mean gets 1 - alpha
          (default: 0.5). Classes absent from ``labels`` are left unchanged.
        """
        keep = alpha
        take = 1 - alpha
        for cls in range(self.num_classes):
            members = features[labels == cls]
            if members.size(0) == 0:
                continue
            centroid = members.mean(dim=0)
            self.prototypes.data[cls] = keep * self.prototypes.data[cls] + take * centroid
                
                
# Initialize Model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
feature_extractor = ResNet50Extractor().to(device)
# Count the distinct D1 labels. torch.as_tensor(...).unique() gives the same answer for
# lists and numpy arrays as set() did, and is also correct for a torch.Tensor, where
# set() would hash each element by object identity and return the sample count instead.
num_classes = torch.as_tensor(d1_targets).unique().numel()
model = LearningWithPrototypes(feature_extractor, num_classes=num_classes).to(device)


print({"started"})

# Load Pretrained Prototypes if Available
prototypes_path = '/kaggle/input/weight-files/prototypes.pth'
if os.path.exists(prototypes_path):
    # map_location places the tensor on this session's device even if it was saved on another.
    model.prototypes.data = torch.load(prototypes_path, map_location=device)
    print("Loaded saved prototypes.")

# Loss Function, Optimizer, and Scheduler
criterion = nn.CrossEntropyLoss()
alpha = 0.1  # weight of the prototype loss relative to cross-entropy in D1 training
optimizer = optim.AdamW(model.parameters(), lr=0.0001, weight_decay=0.01)
# Multiply the learning rate by 0.1 every 7 scheduler steps (one step per epoch).
scheduler = StepLR(optimizer, step_size=7, gamma=0.1)


# Initial Training on D1
initial_weights_path = '/kaggle/input/weight-files/initial_model_weights.pth'
if os.path.exists(initial_weights_path):
    # map_location lets weights saved on a GPU load in a CPU-only session (and vice versa).
    model.load_state_dict(torch.load(initial_weights_path, map_location=device))
    print("Loaded pretrained model.")
else:
    model.train()
    for epoch in range(17):
        print({epoch})
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs, features = model(images)
            # Supervised objective: cross-entropy on the linear head plus an
            # alpha-weighted distance-to-prototype loss on the features.
            ce_loss = criterion(outputs, labels)
            proto_loss = model.prototype_loss(features, labels)
            loss = ce_loss + alpha * proto_loss
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * images.size(0)
        scheduler.step()
        # Report the per-sample mean training loss for this epoch.
        print(f"Epoch {epoch + 1}: mean training loss {running_loss / len(train_loader.dataset):.4f}")

        # Re-estimate each class prototype as the mean flattened feature of its D1 samples.
        # NOTE(review): this pass reuses train_loader (augmented images) with the model
        # still in train mode, so BatchNorm running statistics are updated here too;
        # consider model.eval() and a test_transform loader for this pass.
        with torch.no_grad():
            all_features = []
            all_labels = []
            for images, labels in train_loader:
                images = images.to(device)
                features = model.feature_extractor(images)
                all_features.append(features.view(features.size(0), -1))
                all_labels.append(labels.to(device))
            all_features = torch.cat(all_features, dim=0)
            all_labels = torch.cat(all_labels, dim=0)
            model.update_prototypes(all_features, all_labels)

    # Checkpoints go to feature_dir, the output directory created at the top of the file.
    torch.save(model.state_dict(), os.path.join(feature_dir, 'initial_model_weights.pth'))
    torch.save(model.prototypes.data, os.path.join(feature_dir, 'prototypes.pth'))
    print("Saved initial model and prototypes.")
   



    

# Self-Training with Distillation
# def self_train_update(model, unlabeled_data,num, num_epochs=12, temperature=2.0, alpha=0.5, beta=0.5):
#     pseudo_labels = []
#     model.eval()
#     with torch.no_grad():
#         for images, _ in DataLoader(CustomImageDataset(unlabeled_data, transform=test_transform), batch_size=32):
#             images = images.to(device)
#             outputs, _ = model(images)
#             _, predicted = torch.max(outputs, 1)
#             pseudo_labels.extend(predicted.cpu().numpy())

#     pseudo_dataset = CustomImageDataset(unlabeled_data, np.array(pseudo_labels), transform=train_transform)
#     pseudo_loader = DataLoader(pseudo_dataset, batch_size=32, shuffle=True)
#     new_model = copy.deepcopy(model)
#     optimizer = optim.AdamW(new_model.parameters(), lr=0.00001, weight_decay=0.01)
#     scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

#     for epoch in range(num_epochs):
#         print({epoch})
#         for images, labels in pseudo_loader:
#             images, labels = images.to(device), labels.to(device)
#             optimizer.zero_grad()
#             student_outputs, features = new_model(images)
#             with torch.no_grad():
#                 teacher_outputs, _ = model(images)
#             ce_loss = criterion(student_outputs, labels)
#             proto_loss = new_model.prototype_loss(features, labels)
#             dist_loss = distillation_loss(student_outputs, teacher_outputs, temperature)
#             loss = ce_loss + alpha * proto_loss + beta * dist_loss
#             loss.backward()
#             optimizer.step()
#         scheduler.step()

#     torch.save(new_model.state_dict(), f'/kaggle/working/files/model_f{num}_weights.pth')
#     print("Saved model after self-training.")
#     return new_model



def _generate_pseudo_labels(teacher, unlabeled_data):
    """Return ``teacher``'s argmax class prediction for every image in ``unlabeled_data``.

    Images are preprocessed with ``test_transform`` (no augmentation) and visited in
    dataset order. ``teacher`` is switched to eval mode and left in it.
    Returns a 1-D numpy array.
    """
    teacher.eval()
    unlabeled_loader = DataLoader(
        CustomImageDataset(unlabeled_data, transform=test_transform), batch_size=32
    )
    predictions = []
    with torch.no_grad():
        for images, _ in unlabeled_loader:
            outputs, _ = teacher(images.to(device))
            _, predicted = torch.max(outputs, 1)
            predictions.extend(predicted.cpu().numpy())
    return np.array(predictions)


def self_train_update(
    old_model,
    unlabeled_data,
    num,
    num_epochs=15,
    temperature=2.0,
    alpha=0.5,
    beta=0.5
):
    """
    Self-training with pseudo-labels, distillation, and prototype updating.

    Steps:
    1. ``old_model`` (the teacher) pseudo-labels ``unlabeled_data``.
    2. A deep copy of it (the student) is fine-tuned in training mode on the augmented,
       pseudo-labeled data. The loss is cross-entropy + alpha * prototype loss +
       beta * distillation loss against the teacher's logits.
    3. The student's prototypes are blended with the teacher's:
       ((num - 1) / num) * old + (1 / num) * new.
    4. The blended student's state_dict is saved as model_f{num}_weights.pth in feature_dir.

    Parameters:
    - old_model: The model from the previous step; acts as teacher and its prototypes are blended.
    - unlabeled_data: Unlabeled dataset to be pseudo-labeled and used for training.
    - num: Current dataset number (>= 1); used for the blend weights and the checkpoint name.
    - num_epochs: Number of epochs for fine-tuning.
    - temperature: Temperature parameter for distillation.
    - alpha: Weight for the prototype loss.
    - beta: Weight for the distillation loss.

    Returns:
    - new_model: Updated model after training, left in eval mode.

    Raises:
    - ValueError: if num < 1 (the blend weights divide by num). Checked before any training.
    """
    if num < 1:
        raise ValueError(f"num must be >= 1, got {num}")

    # Generate pseudo-labels using the old model
    pseudo_labels = _generate_pseudo_labels(old_model, unlabeled_data)

    # Create a pseudo-labeled dataset
    pseudo_dataset = CustomImageDataset(unlabeled_data, pseudo_labels, transform=train_transform)
    pseudo_loader = DataLoader(pseudo_dataset, batch_size=32, shuffle=True)

    # Initialize a new model as a copy of the old one. deepcopy inherits the eval mode
    # set on old_model above, so switch the student back to training mode explicitly
    # (BatchNorm batch statistics, dropout) before fine-tuning.
    new_model = copy.deepcopy(old_model)
    new_model.train()
    optimizer = optim.AdamW(new_model.parameters(), lr=0.00001, weight_decay=0.01)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

    # Train the new model with the pseudo-labeled data
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        for images, labels in pseudo_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()

            # Forward pass with the student (new) model
            student_outputs, features = new_model(images)

            # Teacher outputs for distillation (teacher stays in eval mode, no gradients)
            with torch.no_grad():
                teacher_outputs, _ = old_model(images)

            # Calculate losses
            ce_loss = criterion(student_outputs, labels)
            proto_loss = new_model.prototype_loss(features, labels)
            dist_loss = distillation_loss(student_outputs, teacher_outputs, temperature)

            # Combined loss
            loss = ce_loss + alpha * proto_loss + beta * dist_loss

            # Backpropagation
            loss.backward()
            optimizer.step()
        scheduler.step()

    # Update prototypes using weighted average of old and new model prototypes
    with torch.no_grad():
        weight_old = (num - 1) / num
        weight_new = 1 / num
        new_model.prototypes.data = (
            weight_old * old_model.prototypes.data + weight_new * new_model.prototypes.data
        )
        print("Prototypes updated using weighted average.")

    # Save only after blending so a reloaded checkpoint matches the returned model.
    torch.save(new_model.state_dict(), os.path.join(feature_dir, f'model_f{num}_weights.pth'))
    print(f"Saved model after self-training on dataset {num}.")

    # Return in eval mode, ready for evaluation or to act as the next teacher.
    new_model.eval()
    return new_model

# Function to Calculate Accuracy
def calculate_accuracy(model, data_loader):
    """Return top-1 accuracy of ``model`` on ``data_loader`` as a percentage.

    ``model`` is switched to eval mode (and left there) and must return a
    ``(logits, features)`` tuple; only the logits are used. Batches are moved to
    the module-level ``device``. An empty loader raises ZeroDivisionError.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for batch_images, batch_labels in data_loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            logits = model(batch_images)[0]
            predicted = torch.max(logits, 1).indices
            n_correct += (predicted == batch_labels).sum().item()
            n_seen += batch_labels.size(0)
    return 100 * n_correct / n_seen


# Evaluation loaders for the ten part-one tasks, in task order (D1hat..D10hat).
eval_loaders = [
    d1hat_loader, d2hat_loader, d3hat_loader, d4hat_loader, d5hat_loader,
    d6hat_loader, d7hat_loader, d8hat_loader, d9hat_loader, d10hat_loader,
]
NUM_TASKS = len(eval_loaders)

# accuracy_matrix[i][j] is the accuracy (%) of the model after step i + 1 on D{j + 1}hat;
# entries for tasks not yet reached are None.
accuracy_matrix = []


def _evaluate_step(step_model, step):
    """Evaluate ``step_model`` on D1hat..D{step}hat and record the results.

    Appends a row (padded with None to NUM_TASKS entries) to ``accuracy_matrix``,
    prints each accuracy, and returns the list of accuracies in task order.
    Using the same model for every column avoids copy-paste slips such as
    evaluating one column with the previous step's model.
    """
    accs = [calculate_accuracy(step_model, loader) for loader in eval_loaders[:step]]
    accuracy_matrix.append(accs + [None] * (NUM_TASKS - step))
    for acc in accs:
        print({acc})
    return accs


def _load_or_self_train(prev_model, unlabeled_data, num):
    """Return model f{num}: loaded from /kaggle/input/weight-files if a checkpoint
    exists, otherwise produced by self-training ``prev_model`` on ``unlabeled_data``."""
    ckpt_path = f'/kaggle/input/weight-files/model_f{num}_weights.pth'
    if os.path.exists(ckpt_path):
        # Copy an existing model for its architecture, then overwrite all its state.
        loaded = copy.deepcopy(prev_model)
        loaded.load_state_dict(torch.load(ckpt_path, map_location=device))
        print("Loaded self-trained model.")
        return loaded
    return self_train_update(prev_model, unlabeled_data, num)


[f1_acc] = _evaluate_step(model, 1)

print("hello")
# Train on Unlabeled Data (D2) using self-training and distillation
model_f2 = _load_or_self_train(model, d2_data, 2)
print("started_accuracy_computation")
f2_acc_d1hat, f2_acc_d2hat = _evaluate_step(model_f2, 2)


# Train on Unlabeled Data (D3) using self-training and distillation
model_f3 = _load_or_self_train(model_f2, d3_data, 3)
print("started_accuracy_computation")
f3_acc_d1hat, f3_acc_d2hat, f3_acc_d3hat = _evaluate_step(model_f3, 3)


# Train on Unlabeled Data (D4) using self-training and distillation
model_f4 = _load_or_self_train(model_f3, d4_data, 4)
print("started_accuracy_computation")
f4_acc_d1hat, f4_acc_d2hat, f4_acc_d3hat, f4_acc_d4hat = _evaluate_step(model_f4, 4)


# Train on Unlabeled Data (D5) using self-training and distillation
model_f5 = _load_or_self_train(model_f4, d5_data, 5)
print("started_accuracy_computation")
(f5_acc_d1hat, f5_acc_d2hat, f5_acc_d3hat, f5_acc_d4hat,
 f5_acc_d5hat) = _evaluate_step(model_f5, 5)


# Train on Unlabeled Data (D6) using self-training and distillation
model_f6 = _load_or_self_train(model_f5, d6_data, 6)
print("started_accuracy_computation")
(f6_acc_d1hat, f6_acc_d2hat, f6_acc_d3hat, f6_acc_d4hat,
 f6_acc_d5hat, f6_acc_d6hat) = _evaluate_step(model_f6, 6)


# Train on Unlabeled Data (D7) using self-training and distillation
model_f7 = _load_or_self_train(model_f6, d7_data, 7)
print("started_accuracy_computation")
(f7_acc_d1hat, f7_acc_d2hat, f7_acc_d3hat, f7_acc_d4hat,
 f7_acc_d5hat, f7_acc_d6hat, f7_acc_d7hat) = _evaluate_step(model_f7, 7)


# Train on Unlabeled Data (D8) using self-training and distillation
model_f8 = _load_or_self_train(model_f7, d8_data, 8)
print("started_accuracy_computation")
(f8_acc_d1hat, f8_acc_d2hat, f8_acc_d3hat, f8_acc_d4hat,
 f8_acc_d5hat, f8_acc_d6hat, f8_acc_d7hat, f8_acc_d8hat) = _evaluate_step(model_f8, 8)


# Train on Unlabeled Data (D9) using self-training and distillation
model_f9 = _load_or_self_train(model_f8, d9_data, 9)
print("started_accuracy_computation")
(f9_acc_d1hat, f9_acc_d2hat, f9_acc_d3hat, f9_acc_d4hat, f9_acc_d5hat,
 f9_acc_d6hat, f9_acc_d7hat, f9_acc_d8hat, f9_acc_d9hat) = _evaluate_step(model_f9, 9)


# Train on Unlabeled Data (D10) using self-training and distillation
model_f10 = _load_or_self_train(model_f9, d10_data, 10)
print("started_accuracy_computation")
(f10_acc_d1hat, f10_acc_d2hat, f10_acc_d3hat, f10_acc_d4hat, f10_acc_d5hat,
 f10_acc_d6hat, f10_acc_d7hat, f10_acc_d8hat, f10_acc_d9hat,
 f10_acc_d10hat) = _evaluate_step(model_f10, 10)





{'reached_path_to_datasets'}


  data_dict = torch.load(pth_path)
Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 173MB/s] 


{'started'}
{0}
{1}
{2}
{3}
{4}
{5}
{6}
{7}
{8}
{9}
{10}
{11}
{12}
{13}
{14}
{15}
{16}
Saved initial model and prototypes.
{91.52}
hello
Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 8/15
Epoch 9/15
Epoch 10/15
Epoch 11/15
Epoch 12/15
Epoch 13/15
Epoch 14/15
Epoch 15/15
Saved model after self-training on dataset 2.
Prototypes updated using weighted average.
started_accuracy_computation
{92.2}
{91.24}
Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 8/15
Epoch 9/15
Epoch 10/15
Epoch 11/15
Epoch 12/15
Epoch 13/15
Epoch 14/15
Epoch 15/15
Saved model after self-training on dataset 3.
Prototypes updated using weighted average.
started_accuracy_computation
{91.72}
{91.56}
{90.84}
Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 8/15
Epoch 9/15
Epoch 10/15
Epoch 11/15
Epoch 12/15
Epoch 13/15
Epoch 14/15
Epoch 15/15
Saved model after self-training on dataset 4.
Prototypes updated using weig

Rand Mix using Ada In 