In [44]:
!pip install torch torchvision pytorch-lightning

/Users/akshitsinha3/.zshenv:.:10: no such file or directory: /Users/akshitsinha3/.cargo/env


In [45]:
import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset
import csv
import numpy as np

In [46]:
inp_list = []
out_list = []

with open('balanced_brackets.csv', 'r') as f:
    reader = csv.reader(f)
    next(reader)
    data = list(reader)
    for row in data:
        inp_list.append(row[0])
        out = int(row[1]) > 0
        out_list.append(out)

In [47]:
max_len = max([len(x) for x in inp_list])
max_len

398

In [48]:
def pad_string(s):
    return s + '&' + '-' * (max_len - len(s))

In [49]:
encode_dict = {'(': 0, ')': 1, '-': 2, '&': 3}
decode_dict = {0: '(', 1: ')', 2: '-', 3: '&'}

def encode(s):
    return torch.tensor([encode_dict[c] for c in s], dtype=torch.float)

In [50]:
# train test split
torch.manual_seed(0)

n = len(inp_list)

indices = torch.randperm(n)

train_indices = indices[:int(0.7*n)]
val_indices = indices[int(0.7*n):int(0.85*n)]
test_indices = indices[int(0.85*n):]

X_train = ([encode(pad_string(inp_list[i])) for i in train_indices])
X_val = [encode(pad_string(inp_list[i])) for i in val_indices]
X_test = [encode(pad_string(inp_list[i])) for i in test_indices]

y_train = [torch.tensor(out_list[i]) for i in train_indices]
y_val = [torch.tensor(out_list[i]) for i in val_indices]
y_test = [torch.tensor(out_list[i]) for i in test_indices]



                                                                   

In [51]:
X_train = torch.stack(X_train)
X_val = torch.stack(X_val)
X_test = torch.stack(X_test)

y_train = torch.stack(y_train)
y_val = torch.stack(y_val)
y_test = torch.stack(y_test)

X_train

tensor([[0., 0., 0.,  ..., 2., 2., 2.],
        [0., 0., 0.,  ..., 2., 2., 2.],
        [0., 1., 0.,  ..., 2., 2., 2.],
        ...,
        [0., 0., 1.,  ..., 2., 2., 2.],
        [0., 1., 0.,  ..., 2., 2., 2.],
        [0., 0., 1.,  ..., 2., 2., 2.]])

In [70]:
class Transformer(pl.LightningModule):
    def __init__(self, input_dim, output_dim, num_heads, hidden_dim, num_layers, t_src_mask=None, v_src_mask=None):
        super(Transformer, self).__init__()
        self.embedding = nn.Embedding(input_dim, hidden_dim)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=num_heads)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        self.linear = nn.Linear(hidden_dim, output_dim)
        self.t_src_mask = t_src_mask
        self.v_src_mask = v_src_mask
        

    def forward(self, src, src_mask):
        # Ensure src is a tensor (in case of tuple input from the data loader)
        if isinstance(src, tuple):
            src = src[0]
        # Convert src to the appropriate type (e.g., Long or Int)
        src = src.type(torch.LongTensor).to(self.device)
        embedded_src = self.embedding(src)
        output = self.transformer_encoder(embedded_src, src_key_padding_mask=src_mask.T)
        output = self.linear(output[-1])  # Only using the output of the last time step
        return output

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=0.001)

    def training_step(self, batch, batch_idx):
        inputs, targets = batch
        src_mask = self.t_src_mask[batch_idx].to(self.device)
        outputs = self(inputs, src_mask)
        loss = nn.MSELoss()(outputs, targets)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        inputs, targets = batch
        src_mask = self.v_src_mask[batch_idx].to(self.device)
        outputs = self(inputs, src_mask)
        loss = nn.MSELoss()(outputs, targets)
        self.log('val_loss', loss)
        return loss

    def test_step(self, batch, batch_idx):
        inputs, targets = batch
        src_mask = self.v_src_mask[batch_idx].to(self.device)
        outputs = self(inputs, src_mask)
        loss = nn.MSELoss()(outputs, targets)
        self.log('test_loss', loss)
        return loss

In [53]:
def collate_fn(data):
    # Data is a list of tuples (input_sequence, target_sequence)
    inputs, targets = zip(*data)
    src_mask = torch.stack([torch.tensor([False if x == 2 else True for x in inp]) for inp in inputs])
    src_mask = src_mask.type(torch.BoolTensor)
    # convert src mask to 1d tensor
    src_mask = src_mask.view(src_mask.size(0), -1)
    print(src_mask.shape)
    return inputs, targets, src_mask

In [54]:
train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val, y_val)

In [55]:
train_src_mask = []
val_src_mask = []

In [56]:
re_train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, collate_fn=collate_fn)
re_val_loader = DataLoader(val_dataset, batch_size=128, collate_fn=collate_fn)

In [57]:
from tqdm import tqdm

In [58]:
for data in tqdm(re_train_loader):
    train_src_mask.append(data[2])
for data in tqdm(re_val_loader):
    val_src_mask.append(data[2])



torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])


100%|██████████| 218/218 [00:58<00:00,  3.70it/s]


torch.Size([84, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])




torch.Size([128, 399])


100%|██████████| 47/47 [00:11<00:00,  4.01it/s]

torch.Size([128, 399])
torch.Size([82, 399])





In [59]:
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=128)

In [71]:
model = Transformer(input_dim=4, output_dim=3, num_heads=2, hidden_dim=128, num_layers=1, t_src_mask=train_src_mask, v_src_mask=val_src_mask)
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, train_loader, val_loader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs



  | Name                | Type                    | Params
----------------------------------------------------------------
0 | embedding           | Embedding               | 512   
1 | encoder_layer       | TransformerEncoderLayer | 593 K 
2 | transformer_encoder | TransformerEncoder      | 593 K 
3 | linear              | Linear                  | 387   
----------------------------------------------------------------
1.2 M     Trainable params
0         Non-trainable params
1.2 M     Total params
4.748     Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

/Users/akshitsinha3/Library/CloudStorage/OneDrive-InternationalInstituteofInformationTechnology/3-2/RSAI/Mechanistic_Interpretability/.venv/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
  return F.mse_loss(input, target, reduction=self.reduction)


RuntimeError: The size of tensor a (3) must match the size of tensor b (128) at non-singleton dimension 1

In [None]:
def test_model(model, test_loader):
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for batch in test_loader:
            inputs, targets = batch
            outputs = model(inputs)
            loss = nn.MSELoss()(outputs, targets)
            total_loss += loss.item() * len(inputs)
    avg_loss = total_loss / len(test_loader.dataset)
    print(f"Test Loss: {avg_loss:.4f}")

# Assuming you have a test dataset 'test_dataset' and DataLoader 'test_loader'
test_dataset = TensorDataset(X_test, y_test)  # Replace X_test, y_test with your test data
test_loader = DataLoader(test_dataset, batch_size=16)

# Call the test_model function after training
test_model(model, test_loader)


In [None]:
def inference(model, input_sequence):
    model.eval()
    with torch.no_grad():
        input_sequence = input_sequence.unsqueeze(0)  # Add batch dimension
        output = model(input_sequence)
    return output.squeeze(0)  # Remove batch dimension from output

# Example input sequence (replace this with your own data)
input_sequence = torch.randn(10, 50)  # Example input sequence with shape (seq_len, input_dim)

# Perform inference
output = inference(model, input_sequence)
# take softmax of output
output = nn.Softmax(dim=0)(output)

print("Input Sequence:")
print(input_sequence.shape)
print("Output Sequence:")
print(output.shape)


In [None]:
output

In [None]:
# sum up the output tensor
output.sum()