# GPT-OSS

In [None]:
import torch
import torch.nn as nn
import torchtune.modules as modules
import torch.nn.functional as F

import torch
import torch.nn as nn
import torchtune.modules as modules

class GPT_OSS_Block(nn.Module):
    def __init__(self, d_dim, max_seq_len, num_heads=8, num_kv_heads=4):
        super().__init__()
        
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = d_dim // num_heads
        
        # Rotary embeddings applied per head
        self.rope = modules.RotaryPositionalEmbeddings(
            dim=self.head_dim, max_seq_len=max_seq_len
        )
        
        # Define the projection layers
        q_proj = nn.Linear(d_dim, d_dim, bias=False)
        k_proj = nn.Linear(d_dim, self.num_kv_heads * self.head_dim, bias=False)
        v_proj = nn.Linear(d_dim, self.num_kv_heads * self.head_dim, bias=False)
        out_proj = nn.Linear(d_dim, d_dim, bias=False)
        
        # Torchtune MultiHeadAttention requires you to provide them
        self.gqa = modules.attention.MultiHeadAttention(
            embed_dim=d_dim,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=self.head_dim,
            q_proj=q_proj,
            k_proj=k_proj,
            v_proj=v_proj,
            output_proj=out_proj,
        )
        
        self.rmsnorm1 = modules.RMSNorm(d_dim)
        self.rmsnorm2 = modules.RMSNorm(d_dim)
        
        self.moe = MoE(d_dim, d_ff=d_dim*2)

    def forward(self, x):
        B, L, D = x.shape
        
        embs = self.rmsnorm1(x)
        
        # --- Project to Q, K, V ---
        q = self.gqa.q_proj(embs).view(B, L, self.num_heads, self.head_dim)
        k = self.gqa.k_proj(embs).view(B, L, self.num_kv_heads, self.head_dim)
        v = self.gqa.v_proj(embs).view(B, L, self.num_kv_heads, self.head_dim)
        # print(q.shape)
        # --- Apply RoPE ---
        q = self.rope(q)
        k = self.rope(k)
        # print(q.shape)
        # --- Run scaled dot product attention (torchtune already has helper) ---
        attn_out = self.gqa._attention_call(q, k, v,mask=None,        # no mask here
            dropout_p=0.1,    # inference
            is_causal=True)
        attn_out = attn_out.transpose(1, 2).contiguous().view(B, L, D)
        attn_out = self.gqa.output_proj(attn_out)
        # Residual + MoE
        # print(attn_out.shape)
        embs = attn_out + embs
        embs = embs.view(B, L, -1)
        
        embs = self.rmsnorm2(embs)
        embs = embs + self.moe(embs)
        
        return embs


        
class MoE(nn.Module):
    
    def __init__(self, d_dim, d_ff, num_experts=32, top_k = 4, capacity_factor = 1.25, dropout=0.1):
        super().__init__()
        
        self.num_experts = num_experts
        self.top_k = top_k
        
        self.router = nn.Linear(d_dim, num_experts)
        
        self.experts = nn.ModuleList([nn.Sequential(
            nn.Linear(d_dim, d_ff),
            nn.SiLU(),
            nn.Linear(d_ff, d_dim),
            nn.Dropout(dropout)
        ) for _ in range(self.num_experts)])
        
    def forward(self, x):
        
        bsz, seq_len, d_model = x.shape
        
        logits = self.router(x)
        
        gate_probs = F.softmax(logits, dim=-1)
        
        topk_vals, topk_idx = torch.topk(gate_probs, self.top_k, dim=-1)
        
        topk_vals = topk_vals/topk_vals.sum(dim=-1, keepdim=True)
        
        out = torch.zeros_like(x)
        for k in range(self.top_k):
            idx =  topk_idx[..., k]
            prob = topk_vals[..., k].unsqueeze(-1)
            
            for expert_id in range(self.num_experts):
                mask = (idx==expert_id)
                
                if mask.any():
                    expert_in = x[mask]
                    expert_out = self.experts[expert_id](expert_in)
                    out[mask] += prob[mask]*expert_out
        
        return out
    
moe = MoE(128, 256)
x = torch.randn((1, 100, 128))
out = moe(x)
print(out.shape)

# input_ids = torch.randint(0, 100, (1, 100))
model = GPT_OSS_Block(128, 100)
out = model(x)
print(out.shape)

torch.Size([1, 100, 128])
torch.Size([1, 100, 128])


# Reasoning + Clarification Model Testing

In [None]:
from unsloth import FastLanguageModel
import torch

# Load the trained model from the output directory
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "outputs",  # Your output directory
    max_seq_length = 2048,
    dtype = None,
    load_in_4bit = True,
)

# Save as GGUF for Ollama
model.save_pretrained_gguf("model", tokenizer, quantization_method="q4_k_m")

print("Model saved as GGUF format in 'model' directory!")

# Qformer Library

In [1]:
import torch
from qformer import QFormer

In [2]:
x = torch.randn((1, 32, 512))

img = torch.randn((1,3,224,224))

# Create an instance of the QFormer model with the following parameters:
# - input_size: 512
# - num_heads: 8
# - num_layers: 8
# - dropout: 0.1
# - num_classes: 2
# - num_patches: 2
q_former = QFormer(512, 8, 8, 0.1, 2, 2)
y = q_former(x, img)
print(y.shape)

torch.Size([1, 32, 512])


In [None]:
with open("Train_data.jsonl", 'r') as f:
    # while f.next():
    lines = f.readlines()

print(lines)

In [None]:
from datasets import Dataset

dataset = Dataset.from_dict()

# Surprise Mechanism

In [17]:
import torch
import torch.nn as nn


class SurpriseMechanism(nn.Module):
    
    def __init__(self, input_dim, num_slots, slot_dim, learning_rate = 0.6):
        super().__init__()
        
        self.slots = torch.randn((1, num_slots, slot_dim))
        
        self.fc1 = nn.Linear(input_dim, slot_dim)
        
        self.lr = learning_rate
    
    def forward(self, x):
        
        proj = self.fc1(x)
    
        
        surprise = nn.functional.mse_loss(self.slots, proj.expand_as(self.slots), reduction='none').sum(dim=-1)
        
        min_index = torch.argmin(surprise, dim=-1)
        print(min_index)
        
        self.slots.data[0][min_index] = (1-self.lr)*self.slots.data[0][min_index] + self.lr*proj
        
        return surprise        
        


model = SurpriseMechanism(4,2,8)

x_t = torch.randn((1, 4))

out = model(x_t)
print(out)

tensor([1])
tensor([[11.3411,  6.9629]], grad_fn=<SumBackward1>)


In [15]:
model.slots.data.shape

torch.Size([1, 2, 8])

In [20]:
slots = torch.randn((1, 2, 4))
slots

tensor([[[ 0.2288,  0.0145,  0.2790,  1.2779],
         [ 1.9375, -0.7367, -2.1079, -0.0094]]])

In [19]:
x_t.expand_as(slots).shape

torch.Size([1, 2, 4])

In [23]:
x_t

tensor([[-0.1775,  1.4299,  0.2057, -0.6213]])

In [22]:
x_t.expand_as(slots)

tensor([[[-0.1775,  1.4299,  0.2057, -0.6213],
         [-0.1775,  1.4299,  0.2057, -0.6213]]])

# HRM Model

In [None]:
import torch
import torch.nn as nn

class HRM(nn.Module):
    
    def __init__(self, hidden_size, vocab_size, context_length, output_size, h_cycle = 4, l_cycle = 8, device='cpu'):
        super().__init__()
        
        self.h_cycle = h_cycle
        self.l_cycle = l_cycle
        
        self.token_embed = nn.Embedding(vocab_size, hidden_size)
        self.pos_embed = nn.Embedding(context_length, hidden_size)
        self.low = nn.GRUCell(input_size=hidden_size*4, hidden_size=hidden_size*4, device=device,)
        self.high = nn.GRUCell(input_size=hidden_size*4, hidden_size=hidden_size*4, device=device)
        
        # self.low = nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size, device=device,)
        # self.high = nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size, device=device)
        
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size*4, hidden_size*2),
            nn.ReLU(),
            nn.Linear(hidden_size*2, output_size)
        )
        
    def forward(self, tokens):
        
        token_embs = self.token_embed(tokens)
        pos_embs = self.pos_embed(torch.arange(0, tokens.shape[-1]).to(tokens.device))
        
        embs = token_embs+pos_embs
        embs = embs.view(tokens.shape[0], -1)
        # print(embs.shape)
        # hx, cx = torch.zeros((tokens.shape[0],pos_embs.shape[-1])), torch.zeros((tokens.shape[0],pos_embs.shape[-1]))
        z_l = torch.zeros((tokens.shape[0],embs.shape[-1]))
        # print(z_l.shape)
        for i in range(self.h_cycle*self.l_cycle):
            z_l = self.low(embs, z_l)
            if i%self.h_cycle == 0: 
                # print(f"at {i}")
                z_h = self.high(embs, z_l)
                z_l = z_h
        # print('here')
        out = self.mlp(z_h)
        return out
    

model = HRM(32, 8, 4, 8)

x = torch.randint(0,8, (1,4))
out = model(x)
print(out.shape)


at 0
at 4
at 8
at 12
at 16
at 20
at 24
at 28
torch.Size([1, 8])


In [None]:
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    
    def __init__(self, in_channels=1, embed_dim=512, patch_size=16):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, embed_dim, patch_size, patch_size)

    def forward(self, x):
        out = self.proj(x)
        out = out.flatten(2)
        return out

class HRMVision(nn.Module):
    
    def __init__(self, output_size,in_channels=4, sequence_length = 16, patch_size=64, embed_dim=16, h_cycle = 4, l_cycle = 8, device='cpu'):
        super().__init__()
        
        
        self.h_cycle = h_cycle
        self.l_cycle = l_cycle
        self.context_length =  16
        self.patchify = PatchEmbedding(in_channels, sequence_length, patch_size)
        
        # self.token_embed = nn.Embedding(vocab_size, hidden_size)
        self.pos_embed = nn.Embedding(self.context_length, embed_dim)
        self.low = nn.GRUCell(input_size=embed_dim*embed_dim, hidden_size=embed_dim*embed_dim, device=device,)
        self.high = nn.GRUCell(input_size=embed_dim*embed_dim, hidden_size=embed_dim*embed_dim, device=device)
        
        # self.low = nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size, device=device,)
        # self.high = nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size, device=device)
        
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim*embed_dim, embed_dim*embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim*embed_dim, output_size)
        )
        
    def forward(self, image):
        
        token_embs = self.patchify(image)
        
        pos_embs = self.pos_embed(torch.arange(0, self.context_length).to(image.device))
        embs = token_embs+pos_embs
        embs = embs.view(image.shape[0], -1)
        # hx, cx = torch.zeros((tokens.shape[0],pos_embs.shape[-1])), torch.zeros((tokens.shape[0],pos_embs.shape[-1]))
        z_l = torch.zeros((image.shape[0],embs.shape[-1]))
        # print(z_l.shape)
        for i in range(self.h_cycle*self.l_cycle):
            z_l = self.low(embs, z_l)
            if i%self.h_cycle == 0: 
                # print(f"at {i}")
                z_h = self.high(embs, z_l)
                z_l = z_h
        # print('here')
        out = self.mlp(z_h)
        return out

model = HRMVision(output_size=10, in_channels=1)

x = torch.randn((1,1, 256, 256))
out = model(x)
print(out.shape)


torch.Size([1, 10])


In [10]:
model.named_modules

<bound method Module.named_modules of HRMVision(
  (patchify): PatchEmbedding(
    (proj): Conv2d(1, 16, kernel_size=(64, 64), stride=(64, 64))
  )
  (pos_embed): Embedding(16, 16)
  (low): GRUCell(256, 256)
  (high): GRUCell(256, 256)
  (mlp): Sequential(
    (0): Linear(in_features=256, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=10, bias=True)
  )
)>