# Examine the VQ-VAE codes and try to learn a model on them

In [10]:
import os
import torch
import torch.nn as nn
import numpy as np

In [2]:
from datasets import load_from_disk

In [3]:
# Directory holding the VQ-VAE-encoded songbird dataset (loaded with
# `datasets.load_from_disk`). Set the SONGBIRD_CODES_DIR environment variable to
# point elsewhere instead of editing this user-specific absolute path.
DATASET_DIR = os.environ.get(
    "SONGBIRD_CODES_DIR",
    "/home/gagan/ek_processed_data/vqvae_encoded_songbird_data/model_2024-02-28_firm_galaxy/",
)
ds = load_from_disk(DATASET_DIR)

In [4]:
# Number of rows (examples) in the dataset.
len(ds)

488517

In [6]:
# Show the dataset's feature columns and row count.
ds

Dataset({
    features: ['codes', 'bird_name', 'days_post_hatch', 'recording_date'],
    num_rows: 488517
})

In [20]:
# Codebook indices of the first example. Build an integer tensor directly:
# the legacy `torch.Tensor(...)` constructor always produced float32, which is
# the wrong dtype for indices that will be fed to an nn.Embedding.
codes = torch.tensor(ds[0]["codes"], dtype=torch.long)
codes.shape

torch.Size([10, 64, 8])

In [None]:
# Peek at the first two entries along axis 0, index 0 along axis 1.
# NOTE(review): needs the 3-D `codes` from the cell above; once the flatten
# cell has run, `codes` is 1-D and this indexing raises IndexError.
codes[:2, 0]

In [29]:
# Collapse every axis into a single dimension.
codes = codes.reshape(-1)
codes.shape

torch.Size([5120])

In [25]:
# Cast to int32; nn.Embedding accepts int32 or int64 index tensors.
codes = codes.int()

# Recurrent (GRU) network over the codes

In [17]:
class Rnet(torch.nn.Module):
    """GRU model over sequences of discrete codebook indices.

    Pipeline: ``nn.Embedding`` -> unidirectional ``nn.GRU`` (``batch_first``)
    -> ``nn.LayerNorm`` -> ``nn.Linear`` back to ``num_embeddings`` outputs per
    timestep. Presumably these outputs are logits for next-code prediction --
    TODO confirm against the training loop.

    Args:
        num_embeddings: Size of the code vocabulary; also the number of output
            features per timestep.
        embedding_dim: Dimension of each code embedding (the GRU input size).
        num_layers: Number of stacked GRU layers.
        hidden_size: GRU hidden size (input size of the LayerNorm and Linear).
        dropout: Dropout between stacked GRU layers. Forced to 0.0 when
            ``num_layers <= 1``, because ``nn.GRU`` only applies dropout between
            layers and warns about a non-zero value with a single layer.
    """

    def __init__(self, num_embeddings: int = 512,
                 embedding_dim: int = 512,
                 num_layers: int = 2,
                 hidden_size: int = 512,
                 dropout: float = 0.1,
                 ):
        super().__init__()

        # nn.GRU applies dropout after every layer except the last, so with a
        # single layer it is a no-op that only triggers a UserWarning.
        gru_dropout = dropout if num_layers > 1 else 0.0
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.net = nn.GRU(input_size=embedding_dim, hidden_size=hidden_size,
                          bidirectional=False, batch_first=True,
                          num_layers=num_layers, dropout=gru_dropout)
        self.norm = nn.LayerNorm(hidden_size)
        self.output_layer = nn.Linear(hidden_size, num_embeddings)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map code indices to per-timestep output scores.

        Args:
            x: Integer tensor of code indices with shape ``(B, T)``
                (``batch_first``), or ``(T,)`` for a single unbatched sequence.
                Values must lie in ``[0, num_embeddings)``.

        Returns:
            Tensor of shape ``(B, T, num_embeddings)`` (or ``(T, num_embeddings)``
            for unbatched input).
        """
        emb = self.embedding(x)   # (B, T, embedding_dim)
        out, _ = self.net(emb)    # (B, T, hidden_size); final hidden state unused
        out = self.norm(out)
        return self.output_layer(out)

In [18]:
# Instantiate the GRU model with its hyperparameters spelled out explicitly.
net = Rnet(
    num_embeddings=512,
    embedding_dim=512,
    num_layers=2,
    hidden_size=512,
    dropout=0.1,
)

In [19]:
# Print the module tree to check the layer sizes.
net

Rnet(
  (embedding): Embedding(512, 512)
  (net): GRU(512, 512, num_layers=2, batch_first=True, dropout=0.1)
  (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (output_layer): Linear(in_features=512, out_features=512, bias=True)
)

In [26]:
# Embed every code index; the result has shape codes.shape + (embedding_dim,).
# NOTE(review): the saved z.shape output (1, 10, 512, 512) came from out-of-order
# execution (the flatten cell, In [29], ran after this one). On a fresh
# Restart & Run All, `codes` is (5120,) and z is torch.Size([5120, 512]).
z = net.embedding(codes)

In [27]:
# Shape of the embedded codes. On a fresh top-to-bottom run this is
# torch.Size([5120, 512]); the saved output below is stale (out-of-order execution).
z.shape

torch.Size([1, 10, 512, 512])

In [11]:
# Scratch tensor of uniform [0, 1) samples, unrelated to the code pipeline above.
# NOTE(review): no seed is set, so the values differ on every run.
xx = torch.rand((1, 3, 32, 32))

In [12]:
# Confirm the scratch tensor's shape.
xx.shape

torch.Size([1, 3, 32, 32])