In [1]:
import torch
import torch.nn as nn

In [2]:
import wandb
!wandb login

[34m[1mwandb[0m: Currently logged in as: [33mrajceo2031[0m ([33mrentio[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
from dataclasses import dataclass


@dataclass 
class ModelArgs:
    device = 'cuda'
    no_of_neurons = 128
    block_size = 64
    batch_size = 16
    dropout = 0.1
    epoch = 50
    max_lr = 1e-4
    embedding_dims: int = 784

In [4]:
torch.set_default_device(ModelArgs.device)

In [5]:
import torch
from torch.utils.data import Dataset, DataLoader


torch.manual_seed(0)


num_samples = 10000 
seq_length = ModelArgs.block_size  
device = ModelArgs.device  


t = torch.linspace(0, 100, num_samples + seq_length, device=device)
data = torch.sin(t) + 0.1 * torch.randn_like(t)  


X_tensor = torch.stack([data[i:i+seq_length] for i in range(num_samples)])
y_tensor = data[seq_length:]  # Next value prediction

train_size = int(0.8 * num_samples)

X_train, y_train = X_tensor[:train_size], y_tensor[:train_size]  
X_val, y_val = X_tensor[train_size:], y_tensor[train_size:]  


class TimeSeriesDataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]


train_dataset = TimeSeriesDataset(X_train, y_train)
val_dataset = TimeSeriesDataset(X_val, y_val)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


generator = torch.Generator(device=device)


train_loader = DataLoader(
    train_dataset,
    batch_size=ModelArgs.batch_size,
    shuffle=True,  
    generator=generator,  
    drop_last=True
)

val_loader = DataLoader(
    val_dataset,
    drop_last=True,
    batch_size=ModelArgs.batch_size,
    shuffle=False, 
)



In [6]:
X_tensor

tensor([[-0.0925, -0.0326, -0.2445,  ...,  0.5818,  0.6252,  0.4777],
        [-0.0326, -0.2445,  0.0443,  ...,  0.6252,  0.4777,  0.5906],
        [-0.2445,  0.0443,  0.0277,  ...,  0.4777,  0.5906,  0.5046],
        ...,
        [-0.9308, -0.9053, -0.9149,  ..., -0.4494, -0.6120, -0.6442],
        [-0.9053, -0.9149, -0.7448,  ..., -0.6120, -0.6442, -0.4798],
        [-0.9149, -0.7448, -1.0175,  ..., -0.6442, -0.4798, -0.3828]],
       device='cuda:0')

In [7]:
y_tensor

tensor([ 0.5906,  0.5046,  0.7054,  ..., -0.4798, -0.3828, -0.4920],
       device='cuda:0')

In [8]:
class ResetGate(torch.nn.Module):
    def __init__(self, device, no_of_neurons):
        super().__init__()
        self.linear = nn.Linear(in_features=ModelArgs.no_of_neurons + 1, out_features=no_of_neurons, device=device, dtype=torch.float32)
   
    def forward(self, x, ht_1):
        # print("Reset: ", x.shape)
        # print("Reset: ", ht_1.shape)
        x = torch.cat([x, ht_1], dim=-1)
        out = torch.nn.functional.sigmoid(self.linear(x))
        return out

In [9]:
class UpdateGate(nn.Module):
    def __init__(self, device, no_of_neurons) -> None:
        super().__init__()
        self.linear = nn.Linear(in_features=ModelArgs.no_of_neurons + 1, out_features=no_of_neurons, device=device, dtype=torch.float32)
    def forward(self, x, ht_1):
        x = torch.cat([x, ht_1], dim=-1)
        out = torch.nn.functional.sigmoid(self.linear(x))
        return out

In [10]:
x = torch.randn(16,64)
x2 = torch.randn(16, 128)
torch.cat([x, x2], dim=-1).shape

torch.Size([16, 192])

In [11]:

class GRUBlock(nn.Module):
    def __init__(self, device, no_of_neurons) -> None:
        super().__init__()
        self.update = UpdateGate(device=device, no_of_neurons=no_of_neurons)
        self.reset = ResetGate(device=device, no_of_neurons=no_of_neurons)
        self.candidate = nn.Linear(in_features=ModelArgs.no_of_neurons, out_features=no_of_neurons, device=device, dtype=torch.float32)
        
    def forward(self, x, ht_1, output=None):
        
        # ht = ht_1
        # print(x.shape)
        seq_len = x.shape[1]
        
        # print("LSTM: ",x.shape)
        # print("LSTM: ", self.ht_1.shape)
        # print("LSTM: ", self.ct_1.shape)
        # ht = None
        # combined = None
        if(output is None):
            output = []
            for t in range(seq_len):
                
                xt = x[:, t].unsqueeze(-1)
                
                # print(x.shape)
                # print(xt.shape)
                # print(ht_1.shape)
                reset_out = self.reset(xt, ht_1) #It is not updating the states correctly since using the same hidden states
                # print(xt)
                ht_1_candidate = nn.functional.tanh(self.candidate(reset_out))
                update_out = self.update(xt, ht_1)
                ht = ((1- update_out) * ht_1) + (update_out * ht_1_candidate)
                ht_1 = ht
                output.append(torch.tensor(ht_1))
        
            return ht_1, torch.stack(output, dim=1)
        
            # print("Current hidden state: ", ht)
        elif(output is not None and len(output) != 0):
            for t in range(seq_len):
                
                xt = x[:, t].unsqueeze(-1)
                xt = x[:, t].unsqueeze(-1)
                # print(x.shape)
                # print(xt.shape)
                # print(ht_1.shape)
                ht_1 = output[: ,t]
                reset_out = self.reset(xt, ht_1) #It is not updating the states correctly since using the same hidden states
                # print(xt)
                ht_1_candidate = nn.functional.tanh(self.candidate(reset_out))
                update_out = self.update(xt, ht_1)
                ht = ((1- update_out) * ht_1) + (update_out * ht_1_candidate)
                ht_1 = ht
                    # print("Current hidden state: ", ht)
            return ht_1, None

In [12]:
class GRU(nn.Module):
    def __init__(self, device, no_of_neurons, out_features):
        super().__init__()
        self.block1 = GRUBlock(device=device, no_of_neurons=no_of_neurons)
        self.block2 = GRUBlock(device=device, no_of_neurons=no_of_neurons)
        self.ht_1 = torch.randn(ModelArgs.batch_size, no_of_neurons, device=device, requires_grad=True, dtype=torch.float32)
        # self.ct_1 = torch.randn(ModelArgs.batch_size, no_of_neurons,device=device, requires_grad=True, dtype=torch.float32)
        self.output = nn.Linear(in_features=ModelArgs.no_of_neurons, out_features=out_features, device=device, dtype=torch.float32)
        self.dropout = nn.Dropout(p=ModelArgs.dropout)
        # self.embedding = nn.Embedding()
        
    def forward(self, x):
        # x = 
        # print("LSTM: ",x.shape)
        # print("LSTM: ", self.ht_1.shape)
        # print("LSTM: ", self.ct_1.shape)
        ht, hidden_states = self.block1(x, self.ht_1)
        # print("Ouputs: ", hidden_states.shape)
        # print(ht.shape)
        # print(ct.shape)
        ht, hidden_states = self.block2(x, ht, hidden_states)
        ht = self.dropout(ht)
        # print("After: ", ht.shape)
        out = nn.functional.sigmoid(self.output(ht))
        return out

In [13]:
model = GRU(device=ModelArgs.device, no_of_neurons=ModelArgs.no_of_neurons, out_features=1)
model = model.to(ModelArgs.device)

In [14]:
x = torch.randint(0, 100, (ModelArgs.batch_size,ModelArgs.block_size)) 
x.T[: ,0]

tensor([20,  6, 73, 64, 93, 14, 97, 33, 31, 92, 32, 73, 98, 32, 50, 71, 58, 59,
        52, 99, 39,  6, 97, 98, 60, 92, 16, 30, 15, 86, 15, 95, 24, 89, 55, 20,
        52, 18, 69, 63, 80, 25, 59, 26, 48, 45, 75, 40, 69, 58, 77,  8, 58,  6,
        70, 24, 40, 13, 51, 14, 83, 95, 83, 95], device='cuda:0')

In [15]:
!pip install torchinfo

from torchinfo import summary

x = torch.randint(0, 100, (ModelArgs.batch_size,ModelArgs.block_size))  # Random integer between 0 and 100
x = x.to(ModelArgs.device)
summary(model=model,
        input_data=x,
        # input_size=(ModelArgs.batch_size, ModelArgs.block_size, ModelArgs.embeddings_dims),
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"])




  return func(*args, **kwargs)


Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
GRU (GRU)                                [16, 64]             [16, 1]              --                   True
├─GRUBlock (block1)                      [16, 64]             [16, 128]            --                   True
│    └─ResetGate (reset)                 [16, 1]              [16, 128]            --                   True
│    │    └─Linear (linear)              [16, 129]            [16, 128]            16,640               True
│    └─Linear (candidate)                [16, 128]            [16, 128]            16,512               True
│    └─UpdateGate (update)               [16, 1]              [16, 128]            --                   True
│    │    └─Linear (linear)              [16, 129]            [16, 128]            16,640               True
│    └─ResetGate (reset)                 [16, 1]              [16, 128]            (recursive)          True
│    │    └─Li

In [16]:
criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=ModelArgs.max_lr)

In [None]:
model.train()
train_losses =  torch.zeros(len(train_loader))
val_losses = torch.zeros(len(val_loader))
wandb.init(
    project='GRU-From-Scratch'
)
for epoch in range(ModelArgs.epoch):
    
    count = 0
    for X, y in train_loader:
        y_pred = model(X)
        # print(y_pred.shape)
        loss = criterion(y_pred, y)
        train_losses[count] = loss.item()
        # print("Loss: ", loss.item())
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        count += 1
        
    #     wandb.log({
    #   "Train Loss": loss.item(),
    #   "Val Loss": loss.item(),
    #   "step": count  
    # })
    # count = 0
    print("Epoch: ", epoch, "|", "Step: ", count, "|", "Train Loss: ", loss.item())
    model.eval()
    count = 0
    
    for X, y in val_loader:
        y_pred = model(X)
        # print(y_pred.shape)
        loss = criterion(y_pred, y)
        
        # print("Loss: ", loss.item())
        val_losses[count] = loss.item()
        
        # optimizer.zero_grad()
        # loss.backward()
        # optimizer.step()
        count += 1
    #     wandb.log({
    #   "Train Loss": loss.item(),
    #   "Val Loss": loss.item(),
    #   "step": count  
    # })
        print("Epoch: ", epoch, "|", "Step: ", count, "|", "Val Loss: ", loss.item())
    model.train()
    wandb.log({
      "Train Loss": train_losses.mean(),
      "Val Loss": val_losses.mean(),
      "epoch": epoch  
    })
    print("Epoch: ", epoch, "|", "Train Loss: ", train_losses.mean(),  "|", "Val Loss: ", val_losses.mean())


[34m[1mwandb[0m: Currently logged in as: [33mrajceo2031[0m ([33mrentio[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


  return func(*args, **kwargs)
  return func(*args, **kwargs)


Epoch:  0 | Step:  500 | Train Loss:  0.5754755735397339
Epoch:  0 | Step:  1 | Val Loss:  1.0404133796691895
Epoch:  0 | Step:  2 | Val Loss:  0.9872232675552368
Epoch:  0 | Step:  3 | Val Loss:  0.9304396510124207
Epoch:  0 | Step:  4 | Val Loss:  0.7334321737289429
Epoch:  0 | Step:  5 | Val Loss:  0.6048568487167358
Epoch:  0 | Step:  6 | Val Loss:  0.45914244651794434
Epoch:  0 | Step:  7 | Val Loss:  0.26499342918395996
Epoch:  0 | Step:  8 | Val Loss:  0.151251882314682
Epoch:  0 | Step:  9 | Val Loss:  0.06759858131408691
Epoch:  0 | Step:  10 | Val Loss:  0.025497961789369583
Epoch:  0 | Step:  11 | Val Loss:  0.012865077704191208
Epoch:  0 | Step:  12 | Val Loss:  0.04962214455008507
Epoch:  0 | Step:  13 | Val Loss:  0.16061609983444214
Epoch:  0 | Step:  14 | Val Loss:  0.2801150679588318
Epoch:  0 | Step:  15 | Val Loss:  0.4256272614002228
Epoch:  0 | Step:  16 | Val Loss:  0.5464664101600647
Epoch:  0 | Step:  17 | Val Loss:  0.757870614528656
Epoch:  0 | Step:  18 | Val

  return func(*args, **kwargs)
  return func(*args, **kwargs)


Epoch:  1 | Step:  500 | Train Loss:  0.6364208459854126
Epoch:  1 | Step:  1 | Val Loss:  1.005723476409912
Epoch:  1 | Step:  2 | Val Loss:  0.9535362720489502
Epoch:  1 | Step:  3 | Val Loss:  0.8977890014648438
Epoch:  1 | Step:  4 | Val Loss:  0.7046825289726257
Epoch:  1 | Step:  5 | Val Loss:  0.5786488652229309
Epoch:  1 | Step:  6 | Val Loss:  0.43634361028671265
Epoch:  1 | Step:  7 | Val Loss:  0.24798494577407837
Epoch:  1 | Step:  8 | Val Loss:  0.13851964473724365
Epoch:  1 | Step:  9 | Val Loss:  0.059595949947834015
Epoch:  1 | Step:  10 | Val Loss:  0.022194232791662216
Epoch:  1 | Step:  11 | Val Loss:  0.014941531233489513
Epoch:  1 | Step:  12 | Val Loss:  0.05630975216627121
Epoch:  1 | Step:  13 | Val Loss:  0.17382831871509552
Epoch:  1 | Step:  14 | Val Loss:  0.2979313135147095
Epoch:  1 | Step:  15 | Val Loss:  0.44775378704071045
Epoch:  1 | Step:  16 | Val Loss:  0.5714903473854065
Epoch:  1 | Step:  17 | Val Loss:  0.7871824502944946
Epoch:  1 | Step:  18 |