In [1]:
import sys
sys.path.append('../')

from scripts.losses import *
from scripts.models import *
import numpy as np
import torch 
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from tensorboardX import SummaryWriter



In [2]:
# Experiment configuration: data / model sizes and the RNG seed.
input_dim = 100
num_samples = 1
num_experts = 3
num_classes = 3
hidden_dim = 20

# Seed torch so the dummy data (torch.randn / torch.randint below) and the model's
# weight initialisation are reproducible under Restart Kernel -> Run All.
SEED = 42
torch.manual_seed(SEED)

In [27]:

# Dummy batch: `num_samples` rows of standard-normal features of width `input_dim`.
X = torch.randn(num_samples, input_dim)

# Integer class labels drawn uniformly from [0, num_classes), plus their one-hot form
# (one row per sample, one column per class).
y = torch.randint(low=0, high=num_classes, size=(num_samples,))
y_onehot = F.one_hot(y, num_classes)

In [28]:
# Show the labels and their one-hot encoding alongside their shapes.
for label, value in (
    ("y", y),
    ("y shape", y.shape),
    ("y_onehot", y_onehot),
    ("y_onehot shape", y_onehot.shape),
):
    print(f"{label}: {value}")


y: tensor([2])
y shape: torch.Size([1])
y_onehot: tensor([[0, 0, 1]])
y_onehot shape: torch.Size([1, 3])


In [29]:
# Use the Apple-silicon GPU backend (MPS) when it is available, otherwise the CPU.
if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [30]:
# Loss for the gating output; MSEGatingLoss comes from scripts.losses
# (its exact definition is not shown here).
gating_loss = MSEGatingLoss()
# Loss for the mixture output against integer class labels; "mean" is the default reduction.
expert_loss = nn.CrossEntropyLoss(reduction="mean")

In [31]:
# Mixture-of-experts model (defined in scripts.models), sized from the config cell.
model = MixtureOfExperts(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    num_classes=num_classes,
    num_experts=num_experts,
)


In [32]:
# Check dtypes before feeding the losses (y_onehot is cast to float where it is used).
for label, value in (("X", X), ("y", y), ("y_onehot", y_onehot)):
    print(f"{label} type: {value.dtype}")

X type: torch.float32
y type: torch.int64
y_onehot type: torch.int64


In [44]:
# Single forward pass on the dummy batch; the model returns a 3-tuple
# (mixture output, gating output, expert outputs).
output = model(X)
mixture_out, gating_out, expert_out = output

# Compare output shapes against the targets each one is scored against.
print(f"mixture_out shape: {mixture_out.shape} -- y_shape: {y.shape}")
print(f"gating_out shape: {gating_out.shape} -- y_onehot_shape: {y_onehot.shape}")
print()

# Gating output vs. one-hot targets (cast to float for the loss);
# mixture output vs. integer class labels; combined into one objective.
g_loss = gating_loss(gating_out, y_onehot.float())
e_loss = expert_loss(mixture_out, y)
total_loss = g_loss + e_loss

print(f"output: {output}")

expert_outputs_shape: torch.Size([1, 3, 3])
gating_weights_shape: torch.Size([1, 3])
mixture_output_shape: torch.Size([1, 3])
mixture_out shape: torch.Size([1, 3]) -- y_shape: torch.Size([1])
gating_out shape: torch.Size([1, 3]) -- y_onehot_shape: torch.Size([1, 3])

output: (tensor([[ 0.2149, -0.0186, -0.2226]], grad_fn=<SumBackward1>), tensor([[0.3642, 0.4012, 0.2346]], grad_fn=<SoftmaxBackward0>), tensor([[[ 0.1415, -0.3374, -0.2104],
         [ 0.3098,  0.4765, -0.2298],
         [ 0.1663, -0.3705, -0.2293]]], grad_fn=<StackBackward0>))


In [45]:
# Sanity check: confirm the combined loss dtype, then make sure gradients flow.
# Note: this builds on the graph from the previous cell, so re-running this cell
# alone raises (the graph is freed after the first backward).
print(f"total_loss.dtype: {total_loss.dtype}")
total_loss.backward()

total_loss.dtype: torch.float32


In [47]:
from tqdm import tqdm

In [48]:
# TensorBoard writer (default log directory), smoke-tested with one dummy scalar.
writer = SummaryWriter()
writer.add_scalar(tag="test", scalar_value=1, global_step=1)

In [None]:
epochs = 1
batch_size = 1
lr = 0.001

# Wrap the dummy tensors so every batch unpacks as
# (features, one-hot gating targets, integer class labels).
dataset = TensorDataset(X, y_onehot, y)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Model parameters must live on the same device as the batches moved below.
model = model.to(device)
# NOTE(review): Adam is an assumption (lr=1e-3 is its usual default) — swap if another optimizer was intended.
optimizer = optim.Adam(model.parameters(), lr=lr)

for epoch in tqdm(range(epochs), desc="Epochs"):
    # set model to train
    model.train()

    # track losses, predictions and labels
    total_expert_loss = 0.0
    total_gating_loss = 0.0
    all_preds = []
    all_labels = []

    # loop over data from dataloader
    for inputs, true_gating_labels, labels in tqdm(dataloader, desc="Batches", leave=False):
        # get data to device
        inputs = inputs.to(device)
        true_gating_labels = true_gating_labels.to(device)
        labels = labels.to(device)

        # get output from model (expert_out kept for debugging)
        mixture_out, gating_out, expert_out = model(inputs)

        # Per-batch loss values get their own names: rebinding `gating_loss` /
        # `expert_loss` here would replace the loss modules with tensors and the
        # next batch would fail with "'Tensor' object is not callable".
        batch_gating_loss = gating_loss(gating_out, true_gating_labels.float())
        batch_expert_loss = expert_loss(mixture_out, labels)
        total_loss = batch_gating_loss + batch_expert_loss

        # zero gradients, backpropagate, update weights
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # Accumulate plain Python floats / detached CPU tensors so the autograd
        # graph of each batch is not kept alive across the epoch.
        total_gating_loss += batch_gating_loss.item()
        total_expert_loss += batch_expert_loss.item()
        all_preds.append(mixture_out.argmax(dim=1).detach().cpu())
        all_labels.append(labels.cpu())

    # Per-epoch mean losses and accuracy to TensorBoard (skipped for an empty loader).
    num_batches = len(dataloader)
    if num_batches:
        writer.add_scalar("train/gating_loss", total_gating_loss / num_batches, epoch)
        writer.add_scalar("train/expert_loss", total_expert_loss / num_batches, epoch)
        accuracy = (torch.cat(all_preds) == torch.cat(all_labels)).float().mean().item()
        writer.add_scalar("train/accuracy", accuracy, epoch)