## GEM Architecture Train on Banking 77 dataset

## PIP INSTALLS
---

In [1]:
##@ All necessary pip installs
!pip install -qU transformers datasets accelerate

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.0/44.0 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.0/10.0 MB[0m [31m69.1 MB/s[0m eta [36m0:00:00[0m:00:01[0m:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m485.4/485.4 kB[0m [31m32.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m342.1/342.1 kB[0m [31m23.6 MB/s[0m eta [36m0:00:00[0m
[?25h

## HANDLING WARNINGS
---

In [2]:
import warnings 
warnings.filterwarnings('ignore')

## IMPORTS

---

In [3]:
##@ Core Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import (
    AutoModel, AutoTokenizer, 
    AutoModelForSequenceClassification,
    get_linear_schedule_with_warmup
)
from datasets import load_dataset
from sklearn.cluster import MiniBatchKMeans
from tqdm.auto import tqdm
import numpy as np

## ARCHITECTURE CONFIG
---

In [4]:
# 1. Configuration
class GEMConfig:
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.hidden_size = 768
        self.num_domains = 8
        self.cluster_size = 256
        self.num_classes = 77
        self.num_attention_heads = 12
        self.max_seq_length = 128
        self.batch_size = 32 * max(1, torch.cuda.device_count())  # Dynamic scaling
        self.epochs = 10
        self.learning_rate = 2e-5
        self.gradient_accumulation_steps = 2

In [5]:
# 2. Core Components
class QuantizedBERT(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = AutoModel.from_pretrained("bert-base-uncased")
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
        
    def forward(self, input_ids, attention_mask=None):
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        return self.dequant(self.quant(outputs.last_hidden_state))

class TokenRouter(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.clusterer = MiniBatchKMeans(n_clusters=config.cluster_size)
        self.W_r = nn.Parameter(torch.randn(config.num_domains, config.hidden_size))
        self.threshold = 0.65

    def forward(self, x):
        # Device-safe clustering
        cluster_input = x.detach().cpu().numpy().reshape(-1, x.shape[-1])
        cluster_ids = self.clusterer.fit_predict(cluster_input)
        cluster_ids = torch.tensor(cluster_ids, device=self.config.device).reshape(x.shape[:2])
        
        # Device-aware projections
        domain_logits = torch.einsum('bsh,nh->bsn', x, self.W_r.to(x.device))
        domain_probs = F.softmax(domain_logits, dim=-1)
        routing_mask = (domain_probs.max(-1).values > self.threshold).long()
        
        return domain_probs, routing_mask, cluster_ids

class SCAR(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // self.num_heads
        self.qkv = nn.Linear(config.hidden_size, 3*config.hidden_size)
        self.out = nn.Linear(config.hidden_size, config.hidden_size)

    def create_mask(self, cluster_ids, routing_mask):
        cluster_mask = (cluster_ids.unsqueeze(-1) == cluster_ids.unsqueeze(-2))
        domain_mask = (routing_mask.unsqueeze(-1) == routing_mask.unsqueeze(-2))
        return cluster_mask | domain_mask

    def forward(self, x, cluster_ids, routing_mask):
        B, N, _ = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        attn = (q @ k.transpose(-2, -1)) / np.sqrt(self.head_dim)
        mask = self.create_mask(cluster_ids, routing_mask).unsqueeze(1)
        attn = attn.masked_fill(~mask, -1e9)
        
        attn = F.softmax(attn, dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        return self.out(x)

In [6]:
# 3. Complete GEM Model
class GEM(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.bert = QuantizedBERT()
        self.router = TokenRouter(config)
        self.scar = SCAR(config)
        self.classifier = nn.Linear(config.hidden_size, config.num_classes)
        
        # Teacher model with proper device placement
        self.teacher = AutoModelForSequenceClassification.from_pretrained(
            "bert-base-uncased", num_labels=config.num_classes
        ).eval().to(config.device).requires_grad_(False)

    def forward(self, input_ids, attention_mask=None):
        x = self.bert(input_ids, attention_mask=attention_mask)
        domain_probs, routing_mask, cluster_ids = self.router(x)
        x = self.scar(x, cluster_ids, routing_mask)
        return self.classifier(x[:, 0, :])

    def qakp_loss(self, outputs, labels, input_ids):
        task_loss = F.cross_entropy(outputs, labels)
        quant_error = F.mse_loss(self.bert.quant(self.bert.dequant(outputs)), outputs)
        
        with torch.no_grad():
            teacher_logits = self.teacher(input_ids).logits
        
        kd_loss = F.kl_div(
            F.log_softmax(outputs, dim=-1),
            F.softmax(teacher_logits, dim=-1),
            reduction='batchmean'
        )
        
        return task_loss + 0.3*quant_error + 0.7*kd_loss

In [7]:
# 6. Evaluation and Deployment (Fixed)
def prepare_dataloaders(config):
    dataset = load_dataset("banking77")
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    
    def tokenize(examples):
        return tokenizer(
            examples['text'],
            padding='max_length',
            truncation=True,
            max_length=config.max_seq_length
        )
    
    dataset = dataset.map(tokenize, batched=True)
    
    def collate(batch):
        return {
            'input_ids': torch.stack([torch.tensor(x['input_ids']) for x in batch]),
            'attention_mask': torch.stack([torch.tensor(x['attention_mask']) for x in batch]),
            'labels': torch.tensor([x['label'] for x in batch])
        }
    
    train_loader = DataLoader(
        dataset['train'],
        batch_size=config.batch_size,
        shuffle=True,
        collate_fn=collate
    )
    
    test_loader = DataLoader(
        dataset['test'],
        batch_size=config.batch_size,
        collate_fn=collate
    )
    
    return train_loader, test_loader

## TRAINING LOOP 
---

In [8]:
##@ Training Loop with Multi-GPU Support
def train_model():
    config = GEMConfig()
    train_loader, test_loader = prepare_dataloaders(config)
    
    model = GEM(config)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
    model.to(config.device)
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=100,
        num_training_steps=len(train_loader)*config.epochs
    )
    
    model.train()
    for epoch in range(config.epochs):
        total_loss = 0
        for step, batch in enumerate(tqdm(train_loader)):
            inputs = batch['input_ids'].to(config.device)
            masks = batch['attention_mask'].to(config.device)
            labels = batch['labels'].to(config.device)
            
            outputs = model(inputs, attention_mask=masks)
            loss = model.module.qakp_loss(outputs, labels, inputs)
            
            loss.backward()
            if (step+1) % config.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
            
            total_loss += loss.item()
        
        print(f"Epoch {epoch+1} | Avg Loss: {total_loss/len(train_loader):.4f}")
    
    return model

##@ Evaluation & Saving
def evaluate_and_save(model):
    config = GEMConfig()
    _, test_loader = prepare_dataloaders(config)
    
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for batch in test_loader:
            inputs = batch['input_ids'].to(config.device)
            masks = batch['attention_mask'].to(config.device)
            labels = batch['labels'].to(config.device)
            
            outputs = model(inputs, attention_mask=masks)
            preds = outputs.argmax(dim=-1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    
    print(f"Final Accuracy: {100*correct/total:.2f}%")
    torch.save(model.module.state_dict() if hasattr(model, 'module') else model.state_dict(), 
               "gem_model.pth")

if __name__ == "__main__":
    trained_model = train_model()
    evaluate_and_save(trained_model)

README.md:   0%|          | 0.00/14.4k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/298k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/93.9k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/10003 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/3080 [00:00<?, ? examples/s]

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Map:   0%|          | 0/10003 [00:00<?, ? examples/s]

Map:   0%|          | 0/3080 [00:00<?, ? examples/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


  0%|          | 0/157 [00:00<?, ?it/s]

We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


Epoch 1 | Avg Loss: 4.2777


  0%|          | 0/157 [00:00<?, ?it/s]

Epoch 2 | Avg Loss: 3.0262


  0%|          | 0/157 [00:00<?, ?it/s]

Epoch 3 | Avg Loss: 1.9194


  0%|          | 0/157 [00:00<?, ?it/s]

Epoch 4 | Avg Loss: 1.5360


  0%|          | 0/157 [00:00<?, ?it/s]

Epoch 5 | Avg Loss: 1.3822


  0%|          | 0/157 [00:00<?, ?it/s]

Epoch 6 | Avg Loss: 1.3048


  0%|          | 0/157 [00:00<?, ?it/s]

Epoch 7 | Avg Loss: 1.2548


  0%|          | 0/157 [00:00<?, ?it/s]

Epoch 8 | Avg Loss: 1.2228


  0%|          | 0/157 [00:00<?, ?it/s]

Epoch 9 | Avg Loss: 1.1982


  0%|          | 0/157 [00:00<?, ?it/s]

Epoch 10 | Avg Loss: 1.1828


Map:   0%|          | 0/3080 [00:00<?, ? examples/s]

Final Accuracy: 92.56%
