In [1]:
import torch

In [2]:
import pytorch_lightning as L

In [3]:
from data_loading import load_data, get_loaders, get_probe_data, get_probe_loaders, make_dataset
from transformer_predictor import TransformerPredictor, train, test

In [5]:
# Source CSV for every split built below.
# Alternative source that was used previously: load_data('brackets.csv')
dataset = load_data(
    'Data/balanced_with_no_count.csv'
)

In [6]:
BATCH_SIZE = 64
SEED = 42

# Fix the Python / NumPy / PyTorch global seeds before any splitting, shuffling
# or weight initialization so that Restart & Run All is repeatable.
L.seed_everything(SEED)

# return_data=True makes get_loaders hand back the three underlying splits in
# addition to the three loaders (six values are unpacked here).
(
    train_loader, val_loader, test_loader,
    train_data, val_data, test_data,
) = get_loaders(dataset, batch_size=BATCH_SIZE, return_data=True)

In [7]:
# Sizes of the train / validation / test splits, in that order.
print(*map(len, (train_data, val_data, test_data)))

24000 4000 12000


In [None]:
# Transformer classifier: one layer, two heads, 128-dim model, two output classes.
# NOTE(review): max_iters is fixed at 1000 while the trainer below runs for
# 10 epochs. If TransformerPredictor's LR schedule is driven by max_iters,
# confirm whether it should be max_epochs * len(train_loader) instead.
model = TransformerPredictor(
    # architecture
    num_layers=1,
    num_heads=2,
    model_dim=128,
    input_dim=4,
    num_classes=2,
    # optimization
    lr=1e-3,
    warmup=100,
    max_iters=1000,
)
trainer = L.Trainer(devices=1, max_epochs=10)

In [None]:
# train() returns its result together with the (re-bound) model and trainer and
# the train / validation outbeddings that the probing cells consume later.
(
    res,
    model,
    trainer,
    train_outbeddings,
    val_outbeddings,
) = train(model, trainer, train_loader, val_loader)

In [None]:
# res, train_outbeddings = test(model, trainer, train_loader)
# res, val_outbeddings = test(model, trainer, val_loader)

In [None]:
# Evaluate on the test split; the third returned value is not used here.
(res, test_outbeddings, _) = test(model, trainer, test_loader)

In [None]:
# Number of collected outbeddings per split (train, val, test).
tuple(map(len, (train_outbeddings, val_outbeddings, test_outbeddings)))

In [None]:
import pandas as pd

# Read the additional evaluation CSV and preview its first five rows.
no_count_df = pd.read_csv('new_test_data.csv')
no_count_df.head(n=5)

In [None]:
# Replace every value in the 'count' column by the absolute value of its integer cast.

counts = no_count_df['count'].values
# Vectorized cast + abs; converted back to a plain array so the assignment is
# positional, exactly like assigning a list.
no_count_df['count'] = pd.Series(counts).astype(int).abs().to_numpy()

In [None]:
# Map each distinct count value to the sub-frame of rows carrying that count.
df_by_count = {
    count_value: no_count_df[no_count_df['count'] == count_value]
    for count_value in no_count_df['count'].unique()
}

In [None]:
from test_data_loading import TestBracketDataset
from torch.utils.data import DataLoader



In [None]:
# Distinct count values, in order of first appearance.
pd.unique(no_count_df['count'])

In [None]:
# For every count value: evaluate the model on that count's rows, keep the test
# result and logits, and store the attention map averaged over the group.
maps_by_count = {}
results = []
mlp_logits = []
for count in no_count_df['count'].unique():
    count_df = df_by_count[count]
    no_count_dataset = TestBracketDataset(count_df)
    no_count_loader = DataLoader(no_count_dataset, batch_size=BATCH_SIZE, num_workers=4)
    res, no_count_outbeddings, maps, logits = test(model, trainer, no_count_loader, attn=1)
    results.append(res)
    mlp_logits.append(logits)

    # Sum every map returned by test(); each entry is added onto a (2, 513, 513) buffer.
    avg_attn_map = torch.zeros((2, 513, 513))
    for batch_maps in maps:
        for attn_map in batch_maps:
            avg_attn_map += attn_map

    # Divide by the real number of rows in this group rather than a hard-coded
    # 200, so the result is a true mean even when group sizes differ.
    # NOTE(review): assumes TestBracketDataset yields one sample per row — confirm.
    avg_attn_map /= len(count_df)
    maps_by_count[count] = avg_attn_map

In [None]:
# Shape of the averaged attention map stored for count 0.
maps_by_count[0].size()

In [None]:
# One entry in `results` per evaluated count value.
num_results = len(results)
print(num_results)

In [None]:
# Save maps_by_count as JSON, one file per count value

import json
from pathlib import Path

# open(..., 'w') does not create missing parent directories, so make sure the
# output folder exists before writing into it.
Path('jsons').mkdir(parents=True, exist_ok=True)

for count in maps_by_count:
    # Each file holds a single-key object: {"<count>": nested list of the map}.
    maps_by_count_json = {str(count): maps_by_count[count].detach().numpy().tolist()}
    with open(f'jsons/maps_by_count_{count}.json', 'w') as f:
        json.dump(maps_by_count_json, f)

In [None]:
# Pull the 'test_acc' entry out of the first element of every test result.
acc = [run_result[0]['test_acc'] for run_result in results]


In [None]:
# Plot accuracies vs count

import matplotlib.pyplot as plt

# `acc` was collected while iterating no_count_df['count'].unique(), which is in
# first-appearance order, not sorted order. Pair each count with its accuracy
# and sort by count so the line runs left to right instead of zig-zagging.
counts_in_eval_order = no_count_df['count'].unique()
assert len(counts_in_eval_order) == len(acc), (
    "acc is out of sync with the counts in no_count_df; re-run the evaluation cell"
)
sorted_counts, sorted_acc = zip(
    *sorted(zip(counts_in_eval_order, acc), key=lambda pair: pair[0])
)

plt.plot(sorted_counts, sorted_acc)
plt.xlabel('Count')
plt.ylabel('Accuracy')
plt.title('Accuracy vs Count')
plt.show()

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# One 40x40-inch heatmap per count value, drawn from entry [1] of that count's
# averaged attention map in `maps_by_count`.
for count, attention_map in maps_by_count.items():
    fig, ax = plt.subplots(figsize=(40, 40))
    heat_values = attention_map[1].detach().numpy()
    sns.heatmap(heat_values, ax=ax)
    ax.set_title(f'Attention map for count={count}')
    plt.show()



In [None]:
import torch
from torch.utils.data import DataLoader
from data_loading import BracketDataset
from main import Logger

# Logger instance (from main) used by the per-stack-depth evaluation loop at the
# end of this cell, via logger.log(...).
logger = Logger()

def get_stack_depths(loader: DataLoader, stack_depths=(5, 15), batch_size=None):
    """Build a loader restricted to samples whose stack depth is in ``stack_depths``.

    Every batch of ``loader`` is scanned; a sample is kept when its ``'sd'``
    entry is contained in ``stack_depths``. The kept samples are re-batched
    into a new ``DataLoader``.

    Args:
        loader: Iterable of batches. Each batch is indexed with the keys
            ``'x'``, ``'y'``, ``'sd'`` and ``'eos'``, whose entries are aligned
            per sample.
        stack_depths: Collection of depth values to keep.
        batch_size: Batch size of the returned loader. When ``None`` the
            notebook-level ``BATCH_SIZE`` is used.

    Returns:
        ``(stack_loader, num_samples)``; ``(None, 0)`` when no sample matches.
    """
    stack_data = [
        {
            'x': batch['x'][i],
            'y': batch['y'][i],
            'sd': depth,
            'eos': batch['eos'][i],
        }
        for batch in loader
        for i, depth in enumerate(batch['sd'])
        if depth in stack_depths
    ]

    if not stack_data:
        return None, 0

    if batch_size is None:
        batch_size = BATCH_SIZE

    # The list of sample dicts is handed to DataLoader directly. An earlier
    # version also built a BracketDataset from it but never used the result.
    stack_loader = DataLoader(stack_data, batch_size=batch_size)

    return stack_loader, len(stack_data)
    


# Evaluate the model separately on test samples at stack depths 10, 20, ..., 100.
for i in range(10, 101, 10):
    stack_loader, len_stack = get_stack_depths(test_loader, stack_depths=[i])
    if stack_loader is None:
        # No test sample has this depth; nothing to evaluate.
        continue
    # The plain test-set call earlier in this notebook unpacks three values from
    # test(); unpacking exactly two here would raise. Take the first two and
    # ignore whatever else is returned.
    res, stack_outbeddings, *_ = test(model, trainer, stack_loader)
    logger.log(f"Stack depth {i}, Support: {len_stack}, Accuracy: {res}")

In [None]:
# Probe inputs/labels for stack depths 15 and 25, taken from each split's outbeddings.
(
    (X_probe_train, y_probe_train),
    (X_probe_val, y_probe_val),
    (X_probe_test, y_probe_test),
) = (
    get_probe_data(split_outbeddings, stack_depths=[15, 25])
    for split_outbeddings in (train_outbeddings, val_outbeddings, test_outbeddings)
)

# Wrap each (X, y) pair as a dataset, then build the three probe loaders.
train_dataset, val_dataset, test_dataset = (
    make_dataset(features, labels)
    for features, labels in (
        (X_probe_train, y_probe_train),
        (X_probe_val, y_probe_val),
        (X_probe_test, y_probe_test),
    )
)

train_probe_loader, val_probe_loader, test_probe_loader = get_probe_loaders(
    train_dataset,
    val_dataset,
    test_dataset,
    batch_size=64,
)

In [None]:
from probe import Probe, probe_all_models

In [None]:
probe = Probe(model_name='lr')
# The single-probe calls are currently disabled:
# probe.fit(train_probe_loader, val_probe_loader)
# probe.probe(X_probe_train.cpu().detach().numpy(), y_probe_train.cpu().detach().numpy(), X_probe_test.cpu().detach().numpy(), y_probe_test.cpu().detach().numpy())

# Convert the train/test probe tensors to NumPy arrays (train X, train y,
# test X, test y) and pass them positionally to probe_all_models.
probe_all_models(
    *(
        tensor.cpu().detach().numpy()
        for tensor in (X_probe_train, y_probe_train, X_probe_test, y_probe_test)
    )
)

In [None]:
# Decision boundary of `probe` on the test probe data, passed as NumPy arrays.
# NOTE(review): probe.fit(...) is commented out in the cell that creates `probe`;
# confirm plot_decision_boundary does not require a fitted probe.
probe.plot_decision_boundary(
    *(tensor.cpu().detach().numpy() for tensor in (X_probe_test, y_probe_test))
)