In [4]:
import torch
import torch.nn as nn

In [5]:
from dataclasses import dataclass


@dataclass 
class ModelArgs:
    device = 'cuda'
    no_of_neurons = 16
    block_size = 16
    batch_size = 16
    dropout = 0.1
    epoch = 50
    max_lr = 1e-4
    embedding_dims: int = 784

In [6]:
import wandb
!wandb login

[34m[1mwandb[0m: Currently logged in as: [33mrajceo2031[0m ([33mrentio[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [7]:
torch.set_default_device(ModelArgs.device)

In [29]:
import torch
from torch.utils.data import Dataset, DataLoader


torch.manual_seed(0)


num_samples = 10000 
seq_length = ModelArgs.block_size  
device = ModelArgs.device  


t = torch.linspace(0, 100, num_samples + seq_length, device=device)
data = torch.sin(t) 
# data = t

X_tensor = torch.stack([data[i:i+seq_length] for i in range(num_samples)])
y_tensor = data[seq_length:]  # Next value prediction

train_size = int(0.8 * num_samples)

X_train, y_train = X_tensor[:train_size], y_tensor[:train_size]  
X_val, y_val = X_tensor[train_size:], y_tensor[train_size:]  


class TimeSeriesDataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]


train_dataset = TimeSeriesDataset(X_train, y_train)
val_dataset = TimeSeriesDataset(X_val, y_val)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


generator = torch.Generator(device=device)


train_loader = DataLoader(
    train_dataset,
    batch_size=ModelArgs.batch_size,
    shuffle=True,  
    generator=generator,  
    # drop_last=True
)

val_loader = DataLoader(
    val_dataset,
    generator=generator, 
    # drop_last=True,
    batch_size=ModelArgs.batch_size,
    shuffle=True, 
)



In [44]:
class RNNCell(nn.Module):
    def __init__(self, device, no_of_neurons):
        super().__init__()
        
        self.linear_layer = nn.Linear(in_features=ModelArgs.block_size + 1, out_features=no_of_neurons, device=ModelArgs.device)
        
    def forward(self, x, ht_1):
        x = self.linear_layer(torch.cat([x, ht_1], dim=-1))
        ht = torch.nn.functional.sigmoid(x)
        return ht
        

In [49]:
class RNNLayer(nn.Module):
    def __init__(self, device, no_of_neurons):
        super().__init__()
        
        self.rnn_layer = RNNCell(device=device, no_of_neurons=no_of_neurons)
        self.linear_layer = nn.Linear(in_features=ModelArgs.block_size, out_features=no_of_neurons, device=ModelArgs.device)
        
    def forward(self, x):

        ht_1 = torch.zeros((ModelArgs.batch_size, ModelArgs.block_size), device=ModelArgs.device, requires_grad=True, dtype=torch.float32)
        
        seq_len = x.shape[1]
        
        for t in range(seq_len):
            xt = x[:, t]
            xt = xt.unsqueeze(-1)
            ht = self.rnn_layer(xt, ht_1)
            ht_1 = ht
            
        return ht_1

In [50]:
class RNN(nn.Module):
    def __init__(self, device, no_of_neurons, out_features):
        super().__init__()
        self.rnn = RNNLayer(device=device, no_of_neurons=no_of_neurons)

        self.output = nn.Linear(in_features=ModelArgs.no_of_neurons, out_features=out_features, device=device, dtype=torch.float32)
        self.dropout = nn.Dropout(p=ModelArgs.dropout)
        
    def forward(self, x):
        
        ht = self.rnn(x)
        out = self.output(ht)
        out = self.dropout(out)
        return ht
        

In [51]:
# model = GRU(device=ModelArgs.device, no_of_neurons=ModelArgs.no_of_neurons, out_features=1)
model = RNN(device=ModelArgs.device, no_of_neurons=ModelArgs.no_of_neurons, out_features=1)
model = model.to(ModelArgs.device)

In [52]:
!pip install torchinfo

from torchinfo import summary

x = torch.randint(0, 100, (ModelArgs.batch_size,ModelArgs.block_size))  # Random integer between 0 and 100
x = x.to(ModelArgs.device)
summary(model=model,
        input_data=x,
        # input_size=(ModelArgs.batch_size, ModelArgs.block_size, ModelArgs.embeddings_dims),
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"])




Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
RNN (RNN)                                [16, 16]             [16, 16]             --                   True
├─RNNLayer (rnn)                         [16, 16]             [16, 16]             272                  True
│    └─RNNCell (rnn_layer)               [16, 1]              [16, 16]             --                   True
│    │    └─Linear (linear_layer)        [16, 17]             [16, 16]             288                  True
│    └─RNNCell (rnn_layer)               [16, 1]              [16, 16]             (recursive)          True
│    │    └─Linear (linear_layer)        [16, 17]             [16, 16]             (recursive)          True
│    └─RNNCell (rnn_layer)               [16, 1]              [16, 16]             (recursive)          True
│    │    └─Linear (linear_layer)        [16, 17]             [16, 16]             (recursive)          True
│    └─RNNCell

In [53]:
criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=ModelArgs.max_lr)

In [54]:
model.train()
train_losses =  torch.zeros(len(train_loader))
val_losses = torch.zeros(len(val_loader))
wandb.init(
    project='GRU-From-Scratch'
)
for epoch in range(ModelArgs.epoch):
    
    count = 0
    for X, y in train_loader:
        y_pred = model(X)
        # print(y_pred.shape)
        loss = criterion(y_pred, y)
        train_losses[count] = loss.item()
        # print("Loss: ", loss.item())
        
        # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        # if(count != 0):
        #     total_norm_before = torch.norm(
        #             torch.stack([torch.norm(p.grad.detach(), 2) for p in model.parameters()]), 2
        #         )

        #     torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        #         # Compute gradient norms after clipping
        #     total_norm_after = torch.norm(
        #             torch.stack([torch.norm(p.grad.detach(), 2) for p in model.parameters()]), 2
        #         )
                
        #     # if(device  == 0 and step !=0):
        #     print(f"Gradient Norm Before Clipping: {total_norm_before.item():.4f}")
        #     print(f"Gradient Norm After Clipping: {total_norm_after.item():.4f}")
        
        optimizer.zero_grad(set_to_none=True)
        loss.backward(retain_graph=True)
        optimizer.step()
        count += 1
        
    #     wandb.log({
    #   "Train Loss": loss.item(),
    #   "Val Loss": loss.item(),
    #   "step": count  
    # })
    # count = 0
    print("Epoch: ", epoch, "|", "Step: ", count, "|", "Train Loss: ", loss.item())
    model.eval()
    count = 0
    
    for X, y in val_loader:
        y_pred = model(X)
        # print(y_pred.shape)
        loss = criterion(y_pred, y)
        
        # print("Loss: ", loss.item())
        val_losses[count] = loss.item()
        
        # optimizer.zero_grad()
        # loss.backward()
        # optimizer.step()
        count += 1
    #     wandb.log({
    #   "Train Loss": loss.item(),
    #   "Val Loss": loss.item(),
    #   "step": count  
    # })
        print("Epoch: ", epoch, "|", "Step: ", count, "|", "Val Loss: ", loss.item())
    model.train()
    wandb.log({
      "Train Loss": train_losses.mean(),
      "Val Loss": val_losses.mean(),
      "epoch": epoch  
    })
    print("Epoch: ", epoch, "|", "Train Loss: ", train_losses.mean(),  "|", "Val Loss: ", val_losses.mean())


[34m[1mwandb[0m: Currently logged in as: [33mrajceo2031[0m ([33mrentio[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011443493633285269, max=1.0…

  return func(*args, **kwargs)


Epoch:  0 | Step:  500 | Train Loss:  0.656163215637207
Epoch:  0 | Step:  1 | Val Loss:  0.6380040049552917
Epoch:  0 | Step:  2 | Val Loss:  0.6970006227493286
Epoch:  0 | Step:  3 | Val Loss:  0.7294517159461975
Epoch:  0 | Step:  4 | Val Loss:  0.9988241195678711
Epoch:  0 | Step:  5 | Val Loss:  0.6275670528411865
Epoch:  0 | Step:  6 | Val Loss:  0.4303639233112335
Epoch:  0 | Step:  7 | Val Loss:  0.3101731240749359
Epoch:  0 | Step:  8 | Val Loss:  0.8717292547225952
Epoch:  0 | Step:  9 | Val Loss:  0.7538192272186279
Epoch:  0 | Step:  10 | Val Loss:  0.766417920589447
Epoch:  0 | Step:  11 | Val Loss:  0.7333325743675232
Epoch:  0 | Step:  12 | Val Loss:  0.870871901512146
Epoch:  0 | Step:  13 | Val Loss:  0.7366973161697388
Epoch:  0 | Step:  14 | Val Loss:  0.3938369154930115
Epoch:  0 | Step:  15 | Val Loss:  0.6606581211090088
Epoch:  0 | Step:  16 | Val Loss:  0.8170782327651978
Epoch:  0 | Step:  17 | Val Loss:  0.6676695346832275
Epoch:  0 | Step:  18 | Val Loss:  0.

  return func(*args, **kwargs)


Epoch:  1 | Step:  500 | Train Loss:  0.26001042127609253
Epoch:  1 | Step:  1 | Val Loss:  0.8739071488380432
Epoch:  1 | Step:  2 | Val Loss:  1.198178768157959
Epoch:  1 | Step:  3 | Val Loss:  0.6544116735458374
Epoch:  1 | Step:  4 | Val Loss:  0.5835412740707397
Epoch:  1 | Step:  5 | Val Loss:  0.7006656527519226
Epoch:  1 | Step:  6 | Val Loss:  0.8122824430465698
Epoch:  1 | Step:  7 | Val Loss:  0.8781837821006775
Epoch:  1 | Step:  8 | Val Loss:  0.9911659955978394
Epoch:  1 | Step:  9 | Val Loss:  0.6607527732849121
Epoch:  1 | Step:  10 | Val Loss:  0.9287918210029602
Epoch:  1 | Step:  11 | Val Loss:  0.9807298183441162
Epoch:  1 | Step:  12 | Val Loss:  0.5087670087814331
Epoch:  1 | Step:  13 | Val Loss:  0.6317039728164673
Epoch:  1 | Step:  14 | Val Loss:  0.6480879783630371
Epoch:  1 | Step:  15 | Val Loss:  0.8496520519256592
Epoch:  1 | Step:  16 | Val Loss:  1.100818157196045
Epoch:  1 | Step:  17 | Val Loss:  0.6442431211471558
Epoch:  1 | Step:  18 | Val Loss:  