In [None]:
import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer
import random
import torch.nn.functional as F

In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
class GatingMechanism(nn.Module):
    """Top-1 router that picks one expert index per sample from its pooled hidden state.

    Args:
        input_dim: size of the last dimension of the hidden states passed in.
        num_experts: number of experts to choose between.

    Hidden states ``x`` are expected as ``[batch, seq, input_dim]``: dim 1 is
    averaged away before the linear gate is applied.
    """

    def __init__(self, input_dim, num_experts):
        super().__init__()
        self.gate = nn.Linear(input_dim, num_experts).to(device)

    @staticmethod
    def _pool(x, attention_mask=None):
        """Average ``x`` over dim 1, optionally ignoring positions where ``attention_mask`` is 0."""
        if attention_mask is None:
            return x.mean(dim=1)
        weights = attention_mask.unsqueeze(-1).to(x.dtype)
        # clamp avoids 0/0 for a row that is entirely padding
        return (x * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1)

    def scores(self, x, attention_mask=None):
        """Differentiable routing probabilities, shape ``[batch, num_experts]``."""
        return F.softmax(self.gate(self._pool(x, attention_mask)), dim=-1)

    def forward(self, x, attention_mask=None):
        """Return the chosen expert index per sample, shape ``[batch]``.

        ``argmax`` has no gradient, so a loss that depends on the gate only
        through these indices never updates ``self.gate``; use ``scores`` when
        the routing weights themselves must be trained.
        """
        return self.scores(x, attention_mask).argmax(dim=-1)


In [None]:
import torch
import torch.nn as nn

class MoEModelWithPooling(nn.Module):
    """Layer-wise, hard-routed mixture of decoder experts with a pooled classification head.

    At every decoder depth ``i`` the gate picks one expert per sample (argmax
    of softmax over the pooled hidden state). That sample's hidden states are
    passed through layer ``i`` of the chosen expert. After the last layer the
    hidden states are averaged over the sequence, projected to
    ``num_classes`` and softmax-normalised. ``forward`` therefore returns
    class probabilities of shape ``[batch, num_classes]``.

    Args:
        experts: sequence of models exposing ``base_model.model.model.layers``
            and ``base_model.model.model.embed_tokens``. Stored as a plain list,
            so the experts are NOT registered as sub-modules:
            ``parameters()``, ``state_dict()`` and ``.to()`` cover only the gate
            and the head. All experts are assumed to have the same depth; the
            first one defines it.
        input_dim: hidden size of the experts (4096 for the models used in this
            notebook). Also the input size of the classification head.
        num_classes: number of output classes (default 4, the previous
            hard-coded value).
        straight_through_gate: if True (default), each routed sample's expert
            output is scaled by ``1 + (p - p.detach())``, where ``p`` is the
            gate probability of the chosen expert. The forward value is exactly
            1, so outputs are unchanged, but the gate now receives gradients.
            With argmax routing alone it never does. Enabling it means autograd
            keeps the experts' activations for backward, which costs extra
            memory during training.
    """

    def __init__(self, experts, input_dim, num_classes=4, straight_through_gate=True):
        super().__init__()
        self.experts = experts
        self.num_layers = len(experts[0].base_model.model.model.layers)
        self.gating = GatingMechanism(input_dim, len(experts))
        self.pooling = nn.AdaptiveAvgPool1d(1).to(device)
        self.output_layer = nn.Linear(input_dim, num_classes).to(device)
        self.softmax = nn.Softmax(dim=1).to(device)
        self.straight_through_gate = straight_through_gate

    @staticmethod
    def _masked_mean(x, attention_mask):
        """Mean of ``x`` over dim 1, skipping positions where ``attention_mask`` is 0."""
        weights = attention_mask.unsqueeze(-1).to(x.dtype)
        # clamp avoids 0/0 for a row that is entirely padding
        return (x * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1)

    def _route(self, x, attention_mask=None):
        """Return ``(expert index per sample, gate probabilities)`` for hidden states ``x``.

        Mirrors ``GatingMechanism.forward`` (softmax, then argmax over the
        mean-pooled state) but also keeps the probabilities so they can carry
        gradient back to ``self.gating.gate``.
        """
        if attention_mask is None:
            pooled = x.mean(dim=1)
        else:
            pooled = self._masked_mean(x, attention_mask)
        probs = F.softmax(self.gating.gate(pooled), dim=-1)
        return probs.argmax(dim=-1), probs

    def forward(self, input_ids, attention_mask=None):
        """Classify ``input_ids``; returns softmax probabilities ``[batch, num_classes]``.

        ``attention_mask`` (optional, same shape as ``input_ids``) excludes
        padded positions from routing and from the final pooling. When it is
        omitted, every position is averaged, as before.
        """
        # Token embeddings always come from the first expert.
        x = self.experts[0].base_model.model.model.embed_tokens(input_ids)
        for i in range(self.num_layers):
            expert_indices, gate_probs = self._route(x, attention_mask)
            layer_output = torch.zeros_like(x)
            for idx, expert in enumerate(self.experts):
                rows = (expert_indices == idx).nonzero(as_tuple=True)[0]
                if rows.numel() == 0:
                    continue  # no sample routed to this expert: skip its layer
                # Run the layer only on the samples routed to it. The old code
                # pushed the whole (masked) batch through every expert.
                # NOTE(review): the layer receives hidden states only. No
                # attention mask, position ids or causal mask are passed, so
                # confirm the patched layer masks causally and applies RoPE
                # positions as in a full model forward.
                expert_output = expert.base_model.model.model.layers[i](x[rows])[0]
                if self.straight_through_gate:
                    p = gate_probs[rows, idx]
                    # Exactly 1.0 in the forward pass; d(weight)/dp == 1 in backward.
                    weight = 1.0 + (p - p.detach())
                    expert_output = expert_output * weight.to(expert_output.dtype).view(-1, 1, 1)
                layer_output = layer_output.index_copy(0, rows, expert_output.to(layer_output.dtype))
            x = layer_output

        if attention_mask is None:
            # AdaptiveAvgPool1d(1) over [batch, dim, seq] == mean over the sequence.
            pooled = self.pooling(x.transpose(1, 2)).squeeze(2)
        else:
            pooled = self._masked_mean(x, attention_mask)
        return self.softmax(self.output_layer(pooled))

# GatingMechanism is defined in an earlier cell of this notebook.


In [None]:
# %%capture
# Installs Unsloth, Xformers (Flash Attention) and all other packages!
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install --no-deps "xformers<0.0.26" trl peft accelerate bitsandbytes

Collecting unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git
  Cloning https://github.com/unslothai/unsloth.git to /tmp/pip-install-crdv267y/unsloth_c2aa386fdb3b40af955a8cfd4ae05667
  Running command git clone --filter=blob:none --quiet https://github.com/unslothai/unsloth.git /tmp/pip-install-crdv267y/unsloth_c2aa386fdb3b40af955a8cfd4ae05667
  Resolved https://github.com/unslothai/unsloth.git to commit 47ffd39abd02338e8a5f226d0f529347fb7e5f89
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting xformers<0.0.26
  Downloading xformers-0.0.25.post1-cp310-cp310-manylinux2014_x86_64.whl (222.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m222.5/222.5 MB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting trl
  Downloading trl-0.8.6-py3-none-any.whl (245 kB)
[2K   

In [None]:
# Load pre-trained models
from unsloth import FastLanguageModel

max_seq_length = 256 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.


def load_expert(checkpoint):
    """Load one expert checkpoint with the shared Unsloth settings above.

    Returns the ``(model, tokenizer)`` pair from ``FastLanguageModel.from_pretrained``.
    """
    return FastLanguageModel.from_pretrained(checkpoint,
                                             max_seq_length=max_seq_length,
                                             dtype=dtype,
                                             load_in_4bit=load_in_4bit)


model1, tokenizer1 = load_expert("unsloth_domain")
model2, tokenizer2 = load_expert("ai2_arc_instruction_tuned_mistral_7b")

# The MoE model embeds tokens with the first expert's (model1's) embedding table,
# while the data is tokenized with the second load's tokenizer (which previously
# silently overwrote the first). Token ids only line up if the vocabularies agree.
assert tokenizer1.get_vocab() == tokenizer2.get_vocab(), \
    "expert tokenizers disagree; token ids would not match model1's embeddings"
tokenizer = tokenizer2

==((====))==  Unsloth: Fast Mistral patching release 2024.4
   \\   /|    GPU: Tesla V100-PCIE-16GB. Max memory: 15.773 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.2.2+cu121. CUDA = 7.0. CUDA Toolkit = 12.1.
\        /    Bfloat16 = FALSE. Xformers = 0.0.25.post1. FA = False.
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth


Unused kwargs: ['quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.
Unsloth 2024.4 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.
Unused kwargs: ['quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


==((====))==  Unsloth: Fast Mistral patching release 2024.4
   \\   /|    GPU: Tesla V100-PCIE-16GB. Max memory: 15.773 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.2.2+cu121. CUDA = 7.0. CUDA Toolkit = 12.1.
\        /    Bfloat16 = FALSE. Xformers = 0.0.25.post1. FA = False.
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth


In [None]:
models = [model1, model2]

# Freeze every parameter of both experts.
for frozen_model in models:
    for param in frozen_model.parameters():
        param.requires_grad = False

In [None]:
# Depth of the first expert (models[0] is model1).
num_layers = len(models[0].base_model.model.model.layers)
moe_model = MoEModelWithPooling(experts=models, input_dim=4096)

In [None]:
# The experts are kept in a plain Python list, so they do not appear in this repr.
moe_model

MoEModelWithPooling(
  (gating): GatingMechanism(
    (gate): Linear(in_features=4096, out_features=2, bias=True)
  )
  (pooling): AdaptiveAvgPool1d(output_size=1)
  (output_layer): Linear(in_features=4096, out_features=4, bias=True)
  (softmax): Softmax(dim=1)
)


In [None]:
# from torch.cuda.amp import GradScaler

# scaler = GradScaler()

# temp = torch.ones((2,8), dtype=torch.int64).to(device)
# criterion = torch.nn.CrossEntropyLoss()
# labels = torch.rand(2, 4).float().to(device)

# optimizer = torch.optim.Adam(moe_model.parameters(), lr=1e-3)

# optimizer.zero_grad()
# with torch.cuda.amp.autocast():
#     output = moe_model(temp).float()
#     loss = criterion(output, labels.float())

# scaler.scale(loss).backward()
# scaler.step(optimizer)
# scaler.update()

In [None]:
# Reuse the tokenizer's EOS token as its padding token.
pad_source = tokenizer.eos_token
tokenizer.pad_token = pad_source

In [None]:
from datasets import load_dataset, load_from_disk, concatenate_datasets, Dataset

dataset_location = 'medmcqa-prompts'

# Only the train and eval splits are loaded; the test split
# (f"{dataset_location}/test_prompts.hf") is deliberately left out.
train_dataset, eval_dataset = (
    load_from_disk(f"{dataset_location}/train_prompts.hf"),
    load_from_disk(f"{dataset_location}/eval_prompts.hf"),
)

In [None]:
# train = []
# val = []
# count = 0
# for i in train_dataset:
#     train.append(i)
#     count += 1
#     if count >= 100:
#         break

# count = 0
# for i in eval_dataset:
#     val.append(i)
#     count += 1
#     if count >= 100:
#         break

# train_dataset = ''
# eval_dataset = ''

In [None]:
# print(train[0])

In [None]:
from torch.utils.data import DataLoader, Dataset

class MCQDataset(Dataset):
    """Map-style dataset pairing tokenizer output columns with per-example label vectors.

    Args:
        encodings: mapping of column name (e.g. ``input_ids``) to a per-example list.
        labels: per-example label vectors, returned as float tensors under ``'labels'``.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        sample = {}
        for column_name, column in self.encodings.items():
            sample[column_name] = torch.tensor(column[idx])
        # float dtype so one-hot rows can be used directly as soft targets
        sample['labels'] = torch.tensor(self.labels[idx], dtype=torch.float)
        return sample

# Function to encode the data
def encode_data(tokenizer, prompts, max_length=128):
    """Tokenize all ``prompts`` in a single call.

    Args:
        tokenizer: callable tokenizer (here a Hugging Face tokenizer).
        prompts: list of prompt strings.
        max_length: truncation length in tokens (default 128, the previous
            hard-coded value).

    Returns:
        Whatever ``tokenizer(...)`` returns. With Hugging Face tokenizers,
        ``padding=True`` pads every prompt to the longest one in this call.
    """
    return tokenizer(prompts, truncation=True, padding=True, max_length=max_length)

# Prepare the data for tokenization
def _build_split(hf_split, shuffle, batch_size=6):
    """Tokenize one split and wrap it in an ``MCQDataset`` plus a ``DataLoader``.

    Reads the ``'prompt'`` and ``'label_one_hot'`` field of every row of
    ``hf_split`` and returns ``(dataset, loader)``.
    """
    split_prompts = [item['prompt'] for item in hf_split]
    split_labels = [item['label_one_hot'] for item in hf_split]  # one-hot encoded labels
    split_set = MCQDataset(encode_data(tokenizer, split_prompts), split_labels)
    return split_set, DataLoader(split_set, batch_size=batch_size, shuffle=shuffle)


train_set, train_loader = _build_split(train_dataset, shuffle=True)
# Validation metrics do not depend on order, so iterate the eval split deterministically.
eval_set, val_loader = _build_split(eval_dataset, shuffle=False)

In [None]:
shape_fields = (("Input IDs:", 'input_ids'),
                ("Attention Mask:", 'attention_mask'),
                ("Labels:", 'labels'))
example_fields = (("First input IDs example:", 'input_ids'),
                  ("First attention mask example:", 'attention_mask'),
                  ("First label example:", 'labels'))

for i, batch in enumerate(train_loader):
    print("Batch", i)
    for caption, key in shape_fields:
        print(caption, batch[key].shape)

    # Show the raw contents of the first example, but only for the very first batch.
    if i == 0:
        for caption, key in example_fields:
            print(caption, batch[key][0])

    # Three batches are enough for a sanity check.
    if i == 2:
        break

Batch 0
Input IDs: torch.Size([6, 128])
Attention Mask: torch.Size([6, 128])
Labels: torch.Size([6, 4])
First input IDs example: tensor([    1, 28705,    13,  2287, 22478, 28747,    13,  2287,   330,  7749,
         7567,   395,  8039,  1212,   890,  2654,   265, 28708,   395,  4242,
        28723,  1418,  4819, 28719,   806,   385,  6541, 28725,   277,  2857,
          335, 13441,  3042,  1419,   304,  2823,  2458, 28723,  1418, 19869,
         5643,   302,  3042, 28725,   690,  1235,   459,   506,   430,  4206,
          504,   294,   277,   672, 28804,    13,  4018, 28747,    13, 28741,
        28723,  7787,   509,    13, 28760, 28723,   334,   586, 28717,    13,
        28743, 28723,  8990,  2737,   270,   455,   570,  1140,   375,   425,
          472,    13, 28757, 28723,  1234, 28726, 28760, 28750,    13,    13,
         2287,   733, 16289, 28793,   318,  5303,   456,  9713,  4191, 27710,
        22478,   304,  3084,   272,  4714,  3551,   575,   302,  2308,  2877,
        28732

In [None]:
import torch
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm

scaler = GradScaler()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def print_memory_usage():
    """Print the memory PyTorch has allocated and reserved on GPU 0, in GiB rounded to one decimal."""
    bytes_per_gib = 1024 ** 3
    allocated_gib = round(torch.cuda.memory_allocated(0) / bytes_per_gib, 1)
    print('Allocated:', allocated_gib, 'GB')
    reserved_gib = round(torch.cuda.memory_reserved(0) / bytes_per_gib, 1)
    print('Cached:   ', reserved_gib, 'GB')

def compute_loss(outputs, labels, attention_mask):
    """Cross-entropy over the positions where ``attention_mask == 1``.

    ``outputs`` is flattened to ``[-1, outputs.size(-1)]`` while ``labels`` and
    ``attention_mask`` are flattened to 1-D. This presumably assumes
    token-level inputs (outputs ``[batch, seq, classes]``, integer labels and
    mask ``[batch, seq]``). NOTE(review): it does not match the one-hot
    sequence-level labels used by the training loop in this notebook.
    """
    keep = attention_mask.view(-1).eq(1)  # drop padded positions
    num_classes = outputs.size(-1)
    logits = outputs.view(-1, num_classes)[keep]
    targets = labels.view(-1)[keep]
    return F.cross_entropy(logits, targets)

def _soft_target_nll(probs, targets, eps=1e-12):
    """Cross-entropy between predicted class probabilities and target distributions.

    Args:
        probs: ``[batch, num_classes]`` probabilities (rows already softmax-normalised).
        targets: ``[batch, num_classes]`` one-hot (or soft) float targets.
        eps: floor applied before ``log`` so a zero probability does not give ``-inf``.

    Returns:
        Scalar mean loss over the batch.

    ``nn.CrossEntropyLoss`` expects unnormalised logits and applies log-softmax
    itself. Feeding it probabilities, as the previous version did, normalises
    twice: the effective logits are squashed into [0, 1] and the gradients are
    heavily damped.
    """
    log_probs = torch.log(probs.float().clamp_min(eps))
    return -(targets.float() * log_probs).sum(dim=1).mean()


def _train_one_epoch(model, loader, optimizer, scaler, device, use_amp, epoch, log_every=1000):
    """Run one optimisation pass over ``loader``; return ``(mean batch loss, accuracy in %)``."""
    model.train()
    total_loss, total_correct, seen = 0.0, 0, 0
    pbar = tqdm(loader, desc=f"Epoch {epoch+1} [TRAIN]", unit="batch")
    # Iterate the bar itself (not the raw loader) so it actually advances.
    for i, batch in enumerate(pbar):
        input_ids = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=use_amp):
            output = model(input_ids).float()
        loss = _soft_target_nll(output, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Read the loss once; every .item() synchronises with the device.
        loss_value = loss.item()
        total_loss += loss_value
        seen += labels.size(0)
        total_correct += (output.argmax(dim=1) == labels.argmax(dim=1)).sum().item()
        running_acc = total_correct / seen * 100

        pbar.set_postfix(loss=loss_value, temp_acc=running_acc)
        if i % log_every == 0:
            # tqdm's write keeps the progress bar intact, unlike print.
            pbar.write(f"{i} {loss_value}")
            pbar.write(f"Temp accuracy:  {running_acc}")

    return total_loss / max(len(loader), 1), total_correct / max(seen, 1) * 100


def _validate(model, loader, device, use_amp):
    """Evaluate ``model`` on ``loader`` without gradients; return ``(mean batch loss, accuracy in %)``."""
    model.eval()
    total_loss, total_correct, seen = 0.0, 0, 0
    with torch.no_grad():
        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)
            with torch.cuda.amp.autocast(enabled=use_amp):
                outputs = model(input_ids).float()
            total_loss += _soft_target_nll(outputs, labels).item()
            # Labels are one-hot [batch, C]: compare class indices. Comparing the
            # [batch] predictions against the raw [batch, C] labels broadcasts wrongly or raises.
            total_correct += (outputs.argmax(dim=1) == labels.argmax(dim=1)).sum().item()
            seen += labels.size(0)
    return total_loss / max(len(loader), 1), total_correct / max(seen, 1) * 100


def train_and_validate(model, train_loader, val_loader, epochs=3, lr=1e-3):
    """Train ``model`` with Adam (mixed precision on GPU) and validate after every epoch.

    ``model(input_ids)`` must return class probabilities ``[batch, num_classes]``
    (``MoEModelWithPooling`` ends in ``nn.Softmax``). Each batch must be a dict
    with ``'input_ids'`` and one-hot float ``'labels'`` ``[batch, num_classes]``,
    as produced by ``MCQDataset``.

    Args:
        model: module to train; only ``model.parameters()`` are optimised.
        train_loader: iterable of training batches (must support ``len``).
        val_loader: iterable of validation batches (must support ``len``).
        epochs: number of passes over ``train_loader``.
        lr: Adam learning rate (default 1e-3, the previous hard-coded value).

    Returns:
        One dict per epoch with ``train_loss``, ``train_accuracy``, ``val_loss``
        and ``val_accuracy`` (accuracies in percent).
    """
    # Use the device of the trainable parameters instead of hard-coding CUDA;
    # mixed precision is only enabled when that device is a GPU.
    device = next(model.parameters()).device
    use_amp = device.type == "cuda"
    scaler = GradScaler(enabled=use_amp)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    history = []
    for epoch in range(epochs):
        train_loss, train_accuracy = _train_one_epoch(
            model, train_loader, optimizer, scaler, device, use_amp, epoch)
        print("Training Accuracy: ", train_accuracy)
        print(f"Epoch {epoch+1}, Loss: {train_loss}")

        val_loss, val_accuracy = _validate(model, val_loader, device, use_amp)
        print("Validation Accuracy: ", val_accuracy)
        print(f"Epoch {epoch+1} - Validation Loss: {val_loss:.4f}")

        history.append({"train_loss": train_loss, "train_accuracy": train_accuracy,
                        "val_loss": val_loss, "val_accuracy": val_accuracy})
    return history

# Example usage
train_and_validate(model=moe_model, train_loader=train_loader, val_loader=val_loader, epochs=3)


Epoch 1 [TRAIN]:   0%|          | 0/30471 [00:03<?, ?batch/s, loss=1.39, temp_acc=33.3]

0 1.3887276649475098
Temp accuracy:  33.33333333333333


Epoch 1 [TRAIN]:   0%|          | 0/30471 [01:55<?, ?batch/s, loss=1.34, temp_acc=27.4]