In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F

# L = nn.Parameter(torch.randn(4, 2, requires_grad=True))
# U = nn.Parameter(torch.randn(2, 3, requires_grad=True))

# W = torch.sigmoid(50 * (L @ U))  # Should allow grad
# x = torch.randn(1, 3)
# target = torch.randn(1, 4)

# out = F.linear(x, W)
# loss = F.mse_loss(out, target)
# loss.backward()

# print(L.grad)
# print(U.grad)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import DataLoader, ConcatDataset
from torchvision import datasets, transforms
import numpy as np


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ----------------------------
# 1. Feature Extractor (ResNet-18)
# ----------------------------
class FeatureExtractor(nn.Module):
    def __init__(self):
        super(FeatureExtractor, self).__init__()
        resnet = models.resnet18(pretrained=True)
        self.feature_extractor = nn.Sequential(*list(resnet.children())[:-1])

    def forward(self, x):
        x = self.feature_extractor(x)
        return x.view(x.size(0), -1)

# ----------------------------
# 2. Gating Network
# ----------------------------
class GatingNetwork(nn.Module):
    def __init__(self, input_dim, num_experts):
        super(GatingNetwork, self).__init__()
        self.fc = nn.Linear(input_dim, num_experts)

    # def forward(self, x):
    #     logits = self.fc(x)
    #     # gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-10) + 1e-10)
    #     expert_scores = F.softmax((logits), dim=1)
    #     expert_id = torch.argmax(expert_scores, dim=1)
    #     return expert_id, expert_scores

    def forward(self, x):
      logits = self.fc(x)
      gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-10) + 1e-10)
      expert_scores = F.softmax((logits + gumbel_noise) / 0.5, dim=-1)
      # print(expert_scores)
      # print("expert_scores: ",expert_scores)
      return expert_scores  # shape [B, num_experts]

# ----------------------------
# 2. Masked Feature Extractor
# ----------------------------
class MaskedLUFeatureExtractor(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_experts, threshold=0.5, num_layers=4):
        super(MaskedLUFeatureExtractor, self).__init__()
        self.num_experts = num_experts
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.threshold = threshold

        # Create real linear layers (separate per expert and layer)
        self.linear_layers = nn.ModuleDict()

        # Create L, U, and b for mask generation
        self.L_matrices = nn.ParameterDict()
        self.U_matrices = nn.ParameterDict()
        self.bias_masks = nn.ParameterDict()

        for expert_id in range(num_experts):
            for layer_idx in range(num_layers):
                d_in = input_dim if layer_idx == 0 else hidden_dim
                d_out = hidden_dim
                k = max(d_in, d_out)

                # Initialize linear layer for this expert and layer
                lin_key = f"expert{expert_id}_layer{layer_idx}_linear"
                self.linear_layers[lin_key] = nn.Linear(d_in, d_out)

                # L, U, and bias mask
                l_key = f"expert{expert_id}_layer{layer_idx}_L"
                u_key = f"expert{expert_id}_layer{layer_idx}_U"
                b_key = f"expert{expert_id}_layer{layer_idx}_b"

                self.L_matrices[l_key] = nn.Parameter(torch.tril(torch.randn(d_out, k)))
                self.U_matrices[u_key] = nn.Parameter(torch.randn(k, d_in))
                self.bias_masks[b_key] = nn.Parameter(torch.zeros(d_out))

    def forward(self, x, expert_id):
        for i in range(self.num_layers):
            lin_key = f"expert{expert_id}_layer{i}_linear"
            l_key = f"expert{expert_id}_layer{i}_L"
            u_key = f"expert{expert_id}_layer{i}_U"
            b_key = f"expert{expert_id}_layer{i}_b"

            linear = self.linear_layers[lin_key]
            weight = linear.weight         # [d_out, d_in]
            bias = linear.bias             # [d_out]

            # Generate mask using sigmoid(50 * L @ U)
            L = self.L_matrices[l_key]     # [d_out, k]
            U = self.U_matrices[u_key]     # [k, d_in]
            b_mask = self.bias_masks[b_key]  # [d_out]

            W_mask = torch.sigmoid(10 * (L @ U))
            # print(W_mask)          # [d_out, d_in]
            b_mask = torch.sigmoid(b_mask)                  # [d_out]

            # Apply mask
            masked_weight = weight * W_mask                 # [d_out, d_in]
            masked_bias = bias * b_mask                     # [d_out]

            # Linear transformation with masked weights and biases
            x = F.linear(x, masked_weight, masked_bias)
            x = F.relu(x)

        return x

# class MaskedLUFeatureExtractor(nn.Module):
#     def __init__(self, input_dim, hidden_dim, num_experts, num_layers=4):
#         super(MaskedLUFeatureExtractor, self).__init__()
#         self.num_experts = num_experts
#         self.num_layers = num_layers
#         self.hidden_dim = hidden_dim
#         self.input_dim = input_dim

#         self.linear_layers = nn.ModuleDict()

#         # Hardcoded masks per expert and layer
#         self.weight_masks = {}  # shape: [d_out, d_in]
#         self.bias_masks = {}    # shape: [d_out]

#         for expert_id in range(num_experts):
#             self.weight_masks[expert_id] = {}
#             self.bias_masks[expert_id] = {}
#             for layer_idx in range(num_layers):
#                 d_in = input_dim if layer_idx == 0 else hidden_dim
#                 d_out = hidden_dim

#                 lin_key = f"expert{expert_id}_layer{layer_idx}_linear"
#                 self.linear_layers[lin_key] = nn.Linear(d_in, d_out)

#                 # Example: Mask half of the weights/biases randomly
#                 torch.manual_seed(expert_id * 10 + layer_idx)  # for reproducibility
#                 # self.weight_masks[expert_id][layer_idx] = (torch.rand(d_out, d_in) > 0.5).float()
#                 # # you are doing (dout_din) because PyTorch internally does (batch_size *x) * W^T, during matrix multiplication
#                 # self.bias_masks[expert_id][layer_idx] = (torch.rand(d_out) > 0.5).float()

#                 self.weight_masks[expert_id][layer_idx] = torch.ones(d_out, d_in)
#                 self.bias_masks[expert_id][layer_idx] = torch.ones(d_out)

#     def forward(self, x, expert_id):
#         for i in range(self.num_layers):
#             lin_key = f"expert{expert_id}_layer{i}_linear"
#             linear = self.linear_layers[lin_key]
#             weight = linear.weight
#             bias = linear.bias

#             W_mask = self.weight_masks[expert_id][i].to(weight.device)
#             b_mask = self.bias_masks[expert_id][i].to(bias.device)

#             masked_weight = weight * W_mask
#             masked_bias = bias * b_mask
#             # print(masked_weight.size())
#             # print(x.size())
#             x = F.linear(x, masked_weight, masked_bias)
#             x = F.relu(x)

#         return x

# ----------------------------
# 4. Shared Classification Layer
# ----------------------------
class SharedClassifier(nn.Module):
    def __init__(self, hidden_dim, output_dim):
        super(SharedClassifier, self).__init__()
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        return self.fc(x)
        # return F.softmax(self.fc(x), dim=1)

# ----------------------------
# 5. Mixture of Experts
# ----------------------------
class MixtureOfExperts(nn.Module):
    def __init__(self, feature_dim, hidden_dim, num_experts=1, output_dim=200):
        super(MixtureOfExperts, self).__init__()
        self.num_experts = num_experts
        self.feature_extractor = FeatureExtractor()
        self.gating_network = GatingNetwork(feature_dim, num_experts)
        # self.masked_feature_extractor = MaskedFeatureExtractor(feature_dim, hidden_dim, num_experts)
        self.masked_feature_extractor = MaskedLUFeatureExtractor(feature_dim, hidden_dim, num_experts)
        self.classifier = SharedClassifier(hidden_dim, output_dim)

    def forward(self, x):
      features = self.feature_extractor(x)             # [B, feature_dim]
      expert_scores = self.gating_network(features)    # [B, num_experts]

      expert_outputs = []
      for i in range(self.num_experts):
          out = self.masked_feature_extractor(features, expert_id=i)  # [B, hidden_dim]
          expert_outputs.append(out.unsqueeze(1))  # shape [B, 1, hidden_dim]

      expert_outputs = torch.cat(expert_outputs, dim=1)   # [B, num_experts, hidden_dim]
      expert_scores = expert_scores.unsqueeze(2)          # [B, num_experts, 1]

      weighted_features = (expert_outputs * expert_scores).sum(dim=1)  # [B, hidden_dim]
      logits = self.classifier(weighted_features)                      # [B, output_dim]
      return logits



In [None]:
def full_training(model, dataloader, criterion, device, optimizer, num_epochs=50, log_interval=100):
    model.train()

    for epoch in range(num_epochs):
        running_loss = 0.0
        correct, total = 0, 0

        for batch_idx, (images, labels) in enumerate(dataloader):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            # print("Outputs: ",outputs)
            # print("Labels: ",labels)
            loss = criterion(outputs, labels)
            # print(loss)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            # print("Predicted ",predicted)
            # print("Labels: ",labels)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            if batch_idx % log_interval == 0:
                print(f"Batch [{batch_idx}/{len(dataloader)}] - Loss: {loss.item():.4f}")


        # for name, param in model.named_parameters():
        #   print(f"{name}: requires : {param.requires_grad} | grad not None? {param.grad is not None} | grad norm: {param.grad.norm() if param.grad is not None else 'NA'}")
        epoch_loss = running_loss / len(dataloader)
        epoch_accuracy = 100. * correct / total
        print(f"Epoch [{epoch+1}/{num_epochs}] - Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%")


In [None]:
# def retrain_with_mask(model, dataloader, criterion, optimizer, device, num_epochs=5, log_interval=10):
#     model.train()

#     for epoch in range(num_epochs):
#         running_loss = 0.0
#         correct, total = 0, 0

#         for batch_idx, (images, labels) in enumerate(dataloader):
#             images, labels = images.to(device), labels.to(device)

#             optimizer.zero_grad()
#             outputs = model(images)
#             loss = criterion(outputs, labels)
#             loss.backward()
#             optimizer.step()

#             running_loss += loss.item()
#             _, predicted = outputs.max(1)
#             total += labels.size(0)
#             correct += predicted.eq(labels).sum().item()

#             if batch_idx % log_interval == 0:
#                 print(f"Batch [{batch_idx}/{len(dataloader)}] - Loss: {loss.item():.4f}")

#         epoch_loss = running_loss / len(dataloader)
#         epoch_accuracy = 100. * correct / total
#         print(f"Epoch [{epoch+1}/{num_epochs}] - Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%")


In [None]:

# --- Set device ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Transforms ---
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# --- Load individual datasets ---
cub_dataset = datasets.ImageFolder(root='/content/drive/MyDrive/cubs_cropped/train', transform=transform)
flowers_dataset = datasets.ImageFolder(root='/content/drive/MyDrive/flowers/train', transform=transform)

# --- Merge datasets WITHOUT offsetting class labels ---
# ConcatDataset simply creates a virtual concatenation of the datasets
combined_dataset = ConcatDataset([cub_dataset, flowers_dataset])
# combined_dataset = ConcatDataset([cub_dataset])
combined_loader = DataLoader(combined_dataset, batch_size=32, shuffle=True)

# --- Determine total number of classes from combined dataset ---
# Here we assume that each dataset's labels remain unchanged.
# all_targets = []
# for ds in [cub_dataset, flowers_dataset]:
#     all_targets.extend(ds.targets)
#     print(ds.targets)
# # The total number of classes is the maximum label value + 1 (if labels are 0-indexed)
# total_classes = np.max(all_targets) + 1
# print(f"Total classes (combined): {total_classes}")

# --- Initialize the Mixture of Experts model with the combined output dim ---
model = MixtureOfExperts(feature_dim=512, hidden_dim=256, num_experts=2, output_dim=200).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.named_parameters(), lr=0.001)


In [17]:
# --- Full Training on Merged Dataset ---
print("\n==== Training on Combined CUB + Flowers Dataset ====")
full_training(model, combined_loader, criterion, device, optimizer, num_epochs=50, log_interval=10)


==== Training on Combined CUB + Flowers Dataset ====
Batch [0/188] - Loss: 5.2955
Batch [10/188] - Loss: 5.3042
Batch [20/188] - Loss: 5.2983
Batch [30/188] - Loss: 5.3083
Batch [40/188] - Loss: 5.2848
Batch [50/188] - Loss: 5.2003
Batch [60/188] - Loss: 5.0961
Batch [70/188] - Loss: 5.1945
Batch [80/188] - Loss: 5.1006
Batch [90/188] - Loss: 5.3410
Batch [100/188] - Loss: 4.9658
Batch [110/188] - Loss: 5.0655
Batch [120/188] - Loss: 5.2119
Batch [130/188] - Loss: 4.9874
Batch [140/188] - Loss: 4.9123
Batch [150/188] - Loss: 4.8651
Batch [160/188] - Loss: 4.8898
Batch [170/188] - Loss: 4.7365
Batch [180/188] - Loss: 5.0805
Epoch [1/50] - Loss: 5.1071, Accuracy: 1.02%
Batch [0/188] - Loss: 4.7806
Batch [10/188] - Loss: 4.7949
Batch [20/188] - Loss: 4.9834
Batch [30/188] - Loss: 4.6368
Batch [40/188] - Loss: 4.6505
Batch [50/188] - Loss: 4.9260
Batch [60/188] - Loss: 4.7647
Batch [70/188] - Loss: 4.8973
Batch [80/188] - Loss: 4.7679
Batch [90/188] - Loss: 5.0299
Batch [100/188] - Loss: 

KeyboardInterrupt: 

In [None]:
# save the model that has been trained in drive
torch.save(model.state_dict(), "/content/drive/MyDrive/model_1_expert_1_task_50_epochs.pth")
