In [None]:
# Make Google Drive visible inside the Colab runtime; the dataset and the
# saved checkpoint used later in this notebook live under this mount point.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import DataLoader, ConcatDataset
from torchvision import datasets, transforms
import numpy as np


# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# ----------------------------
# 1. Feature Extractor (ResNet-18)
# ----------------------------
class FeatureExtractor(nn.Module):
    """ResNet-18 backbone with its final child module removed.

    All children of a pretrained torchvision ResNet-18 except the last one
    are chained into ``self.feature_extractor``; ``forward`` flattens the
    result to one feature vector per sample.
    """

    def __init__(self):
        super().__init__()
        # torchvision >= 0.13 deprecates ``pretrained=True`` in favour of the
        # ``weights=`` enum. Prefer the enum when it exists and keep the old
        # keyword as a fallback so older torchvision releases still work.
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is not None:
            resnet = models.resnet18(weights=weights_enum.IMAGENET1K_V1)
        else:
            resnet = models.resnet18(pretrained=True)
        # Drop the last child of the network and keep the rest as the backbone.
        self.feature_extractor = nn.Sequential(*list(resnet.children())[:-1])

    def forward(self, x):
        """Return backbone features flattened to shape ``[batch, -1]``."""
        x = self.feature_extractor(x)
        # flatten(x, 1) keeps the batch dimension and merges the rest; unlike
        # ``view`` it does not require the tensor to be contiguous.
        return torch.flatten(x, 1)

# ----------------------------
# 2. Gating Network
# ----------------------------
# class GatingNetwork(nn.Module):
#     def __init__(self, input_dim, num_experts):
#         super(GatingNetwork, self).__init__()
#         self.fc = nn.Linear(input_dim, num_experts)

#     # def forward(self, x):
#     #     logits = self.fc(x)
#     #     # gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-10) + 1e-10)
#     #     expert_scores = F.softmax((logits), dim=1)
#     #     expert_id = torch.argmax(expert_scores, dim=1)
#     #     return expert_id, expert_scores

#     def forward(self, x):
#       logits = self.fc(x)
#       gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-10) + 1e-10)
#       expert_scores = F.softmax((logits + gumbel_noise) / 0.5, dim=-1)
#       # print(expert_scores)
#       # print("expert_scores: ",expert_scores)
#       return expert_scores  # shape [B, num_experts]

# ----------------------------
# 3. Masked Feature Extractor
# ----------------------------
# class MaskedLUFeatureExtractor(nn.Module):
#     def __init__(self, input_dim, hidden_dim, num_experts, threshold=0.5, num_layers=4):
#         super(MaskedLUFeatureExtractor, self).__init__()
#         self.num_experts = num_experts
#         self.num_layers = num_layers
#         self.hidden_dim = hidden_dim
#         self.threshold = threshold

#         # Create real linear layers (separate per expert and layer)
#         self.linear_layers = nn.ModuleDict()

#         # Create L, U, and b for mask generation
#         self.L_matrices = nn.ParameterDict()
#         self.U_matrices = nn.ParameterDict()
#         self.bias_masks = nn.ParameterDict()

#         for expert_id in range(num_experts):
#             for layer_idx in range(num_layers):
#                 d_in = input_dim if layer_idx == 0 else hidden_dim
#                 d_out = hidden_dim
#                 k = max(d_in, d_out)

#                 # Initialize linear layer for this expert and layer
#                 lin_key = f"expert{expert_id}_layer{layer_idx}_linear"
#                 self.linear_layers[lin_key] = nn.Linear(d_in, d_out)

#                 # L, U, and bias mask
#                 l_key = f"expert{expert_id}_layer{layer_idx}_L"
#                 u_key = f"expert{expert_id}_layer{layer_idx}_U"
#                 b_key = f"expert{expert_id}_layer{layer_idx}_b"

#                 self.L_matrices[l_key] = nn.Parameter(torch.tril(torch.randn(d_out, k)))
#                 self.U_matrices[u_key] = nn.Parameter(torch.randn(k, d_in))
#                 self.bias_masks[b_key] = nn.Parameter(torch.zeros(d_out))

#     def forward(self, x, expert_id):
#         for i in range(self.num_layers):
#             lin_key = f"expert{expert_id}_layer{i}_linear"
#             l_key = f"expert{expert_id}_layer{i}_L"
#             u_key = f"expert{expert_id}_layer{i}_U"
#             b_key = f"expert{expert_id}_layer{i}_b"

#             linear = self.linear_layers[lin_key]
#             weight = linear.weight         # [d_out, d_in]
#             bias = linear.bias             # [d_out]

#             # Generate mask using sigmoid(10 * L @ U)
#             L = self.L_matrices[l_key]     # [d_out, k]
#             U = self.U_matrices[u_key]     # [k, d_in]
#             b_mask = self.bias_masks[b_key]  # [d_out]

#             W_mask = torch.sigmoid(10 * (L @ U))
#             # print(W_mask)          # [d_out, d_in]
#             b_mask = torch.sigmoid(b_mask)                  # [d_out]

#             # Apply mask
#             masked_weight = weight * W_mask                 # [d_out, d_in]
#             masked_bias = bias * b_mask                     # [d_out]

#             # Linear transformation with masked weights and biases
#             x = F.linear(x, masked_weight, masked_bias)
#             x = F.relu(x)

#         return x

# class MaskedLUFeatureExtractor(nn.Module):
#     def __init__(self, input_dim, hidden_dim, num_experts, num_layers=4):
#         super(MaskedLUFeatureExtractor, self).__init__()
#         self.num_experts = num_experts
#         self.num_layers = num_layers
#         self.hidden_dim = hidden_dim
#         self.input_dim = input_dim

#         self.linear_layers = nn.ModuleDict()

#         for expert_id in range(num_experts):
#             for layer_idx in range(num_layers):
#                 d_in = input_dim if layer_idx == 0 else hidden_dim
#                 d_out = hidden_dim

#                 lin_key = f"expert{expert_id}_layer{layer_idx}_linear"
#                 self.linear_layers[lin_key] = nn.Linear(d_in, d_out)

#                 # Create and register weight and bias masks
#                 torch.manual_seed(expert_id * 10 + layer_idx)  # for reproducibility

#                 weight_mask = torch.ones(d_out, d_in)
#                 bias_mask = torch.ones(d_out)

#                 self.register_buffer(f"weight_mask_{expert_id}_{layer_idx}", weight_mask)
#                 self.register_buffer(f"bias_mask_{expert_id}_{layer_idx}", bias_mask)

#     def forward(self, x, expert_id):
#         for i in range(self.num_layers):
#             lin_key = f"expert{expert_id}_layer{i}_linear"
#             linear = self.linear_layers[lin_key]

#             weight = linear.weight
#             bias = linear.bias

#             W_mask = getattr(self, f"weight_mask_{expert_id}_{i}")
#             b_mask = getattr(self, f"bias_mask_{expert_id}_{i}")

#             masked_weight = weight * W_mask.detach()
#             masked_bias = bias * b_mask.detach()

#             x = F.linear(x, masked_weight, masked_bias)
#             x = F.relu(x)

#         return x

class MaskedLUFeatureExtractor(nn.Module):
    """Per-expert stack of Linear + ReLU layers (no masking is applied).

    Every expert owns ``num_layers`` linear layers stored in
    ``self.linear_layers`` under the keys ``"expert{e}_layer{l}"``. The first
    layer maps ``input_dim -> hidden_dim``; the remaining ones map
    ``hidden_dim -> hidden_dim``.

    Args:
        input_dim: Size of the incoming feature vector.
        hidden_dim: Output size of every layer in the stack.
        num_experts: Number of independent expert stacks to create.
        num_layers: Depth of each expert stack. Defaults to 4, which
            reproduces the previously hard-coded layer0..layer3 layout, so
            existing checkpoints keep loading.

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """

    def __init__(self, input_dim, hidden_dim, num_experts, num_layers=4):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.num_experts = num_experts
        self.num_layers = num_layers

        self.linear_layers = nn.ModuleDict()

        # Layers are created expert by expert, layer by layer. This order
        # determines both the state_dict ordering and the order in which
        # random initial weights are drawn, so it must not change.
        for expert_id in range(num_experts):
            for layer_idx in range(num_layers):
                in_features = input_dim if layer_idx == 0 else hidden_dim
                self.linear_layers[f"expert{expert_id}_layer{layer_idx}"] = nn.Linear(
                    in_features, hidden_dim
                )

    def forward(self, x, expert_id):
        """Run ``x`` through every Linear + ReLU layer of one expert.

        Args:
            x: Input features whose last dimension is ``input_dim``.
            expert_id: Index of the expert stack to use; an unknown index
                fails the ``ModuleDict`` key lookup.

        Returns:
            The activations after the last ReLU of the selected expert.
        """
        for layer_idx in range(self.num_layers):
            layer = self.linear_layers[f"expert{expert_id}_layer{layer_idx}"]
            x = F.relu(layer(x))
        return x

# ----------------------------
# 4. Shared Classification Layer
# ----------------------------
class SharedClassifier(nn.Module):
    """Single linear head mapping ``hidden_dim`` features to ``output_dim`` logits."""

    def __init__(self, hidden_dim, output_dim):
        super().__init__()
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return raw logits; no softmax is applied here."""
        logits = self.fc(x)
        return logits

# ----------------------------
# 5. Mixture of Experts
# ----------------------------
class MixtureOfExperts(nn.Module):
    """Backbone -> expert stack -> shared classifier pipeline.

    Although ``num_experts`` expert stacks are built, ``forward`` always
    routes through expert 0; there is no gating in this version.

    Args:
        feature_dim: Input size handed to the expert stacks.
        hidden_dim: Width of the expert stacks and classifier input size.
        num_experts: Number of expert stacks to create (default 1).
        output_dim: Number of logits produced by the classifier (default 200).
    """

    def __init__(self, feature_dim, hidden_dim, num_experts=1, output_dim=200):
        super().__init__()
        self.num_experts = num_experts
        self.feature_extractor = FeatureExtractor()
        self.masked_feature_extractor = MaskedLUFeatureExtractor(
            feature_dim, hidden_dim, num_experts
        )
        self.classifier = SharedClassifier(hidden_dim, output_dim)

    def forward(self, x):
        """Return classification logits for a batch of inputs ``x``."""
        backbone_out = self.feature_extractor(x)  # flattened per-sample features
        expert_out = self.masked_feature_extractor(backbone_out, expert_id=0)  # expert 0 only
        return self.classifier(expert_out)  # one row of logits per sample



In [None]:
from torch.utils.data import Dataset

class MyDataset(Dataset):
    """Index-aligned pairing of samples and labels.

    ``data`` and ``labels`` are stored as given; item ``i`` is the tuple
    ``(data[i], labels[i])`` and the length is ``len(data)``.
    """

    def __init__(self, data, labels):
        self.data, self.labels = data, labels

    def __len__(self):
        # Length is taken from ``data`` only; ``labels`` is not checked.
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        target = self.labels[idx]
        return sample, target


In [None]:
def full_training(model, dataloader, criterion, device,optmizer, num_epochs=50, log_interval=100):
    """Train ``model`` for ``num_epochs`` passes over ``dataloader``.

    Args:
        model: Module to train; it is switched to training mode.
        dataloader: Sized iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        device: Device each batch is moved to before the forward pass.
        optmizer: Optimizer used for the updates. The misspelled name is kept
            so existing keyword callers do not break.
        num_epochs: Number of passes over ``dataloader``.
        log_interval: Print the batch loss every ``log_interval`` batches;
            a falsy value (0 / None) disables per-batch logging.

    Returns:
        None. Per-batch and per-epoch loss/accuracy are printed.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    # Bind the argument to a local name. Previously the body referenced a
    # global ``optimizer`` and silently ignored the optimizer passed in
    # (NameError when no such global exists).
    optimizer = optmizer

    num_batches = len(dataloader)
    if num_batches == 0:
        # Fail early with a clear message instead of a ZeroDivisionError when
        # the epoch averages are computed.
        raise ValueError("dataloader is empty; nothing to train on")

    model.train()

    for epoch in range(num_epochs):
        running_loss = 0.0
        correct, total = 0, 0

        for batch_idx, (images, labels) in enumerate(dataloader):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # Predicted class = index of the largest logit along dim 1.
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            if log_interval and batch_idx % log_interval == 0:
                print(f"Batch [{batch_idx}/{num_batches}] - Loss: {loss.item():.4f}")

        # Mean of the per-batch losses, and accuracy over all samples seen.
        epoch_loss = running_loss / num_batches
        epoch_accuracy = 100. * correct / total
        print(f"Epoch [{epoch+1}/{num_epochs}] - Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%")


In [None]:
# --- Set device ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Transforms ---
# Every image is resized to 224x224 and converted to a tensor.
# NOTE(review): no mean/std normalization is applied even though the backbone
# is loaded with pretrained weights — confirm this is intentional.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# --- Load individual datasets ---
cub_dataset = datasets.ImageFolder(root='/content/drive/MyDrive/cubs_cropped/train', transform=transform)

# Alternative data sources tried earlier (kept for reference, currently disabled):
# train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
# combined_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
#
# cub_dataset = datasets.ImageFolder(root='/content/drive/MyDrive/data', transform=transform)
# flowers_dataset = datasets.ImageFolder(root='/content/drive/MyDrive/flowers/train', transform=transform)
#
# --- Merge datasets WITHOUT offsetting class labels ---
# combined_dataset = ConcatDataset([cub_dataset, flowers_dataset])
# combined_loader = DataLoader(combined_dataset, batch_size=32, shuffle=True)

# Despite its name, the loader currently wraps only the CUB dataset.
combined_loader = DataLoader(cub_dataset, batch_size=32, shuffle=True)

# --- Initialize the Mixture of Experts model ---
model = MixtureOfExperts(feature_dim=512, hidden_dim=256, num_experts=1, output_dim=200).to(device)
criterion = nn.CrossEntropyLoss()
# Pass the parameter tensors themselves. ``named_parameters()`` yields
# (name, tensor) tuples, which older PyTorch releases reject with a TypeError;
# ``parameters()`` is accepted by every release.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 157MB/s]


In [None]:
# import matplotlib.pyplot as plt

# images, labels = next(iter(combined_loader))

# # Define class names if you want to display them
# class_names = cub_dataset.classes

# plt.figure(figsize=(10, 4))
# for i in range(4):
#     img = images[i].permute(1, 2, 0).numpy()  # convert to HWC format
#     plt.subplot(1, 4, i + 1)
#     plt.imshow(img)
#     plt.title(class_names[labels[i]])
#     plt.axis('off')
# plt.tight_layout()
# plt.show()

In [None]:
# --- Full training run ---
# `combined_loader` is built from the CUB dataset only in this notebook.
# print("\n==== Training on Combined CUB + Flowers Dataset ====")
full_training(
    model,
    combined_loader,
    criterion,
    device,
    optimizer,
    num_epochs=50,
    log_interval=10,
)

In [None]:
# Persist the trained weights (state_dict only, not the module object) to Drive.
checkpoint_path = "/content/drive/MyDrive/model_4_layer_nomask.pth"
torch.save(model.state_dict(), checkpoint_path)
