In [1]:
from data import TrainDataset, TestDataset
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
from models import SAN

In [23]:
# helper function to calculate accuracy
def calculate_accuracy(outputs, targets):
    """Return the fraction of rows in ``outputs`` whose top-scoring class equals ``targets``.

    Args:
        outputs: tensor of per-class scores; the predicted class is the index of
            the maximum along dimension 1.
        targets: tensor of class indices, one per row of ``outputs``.

    Returns:
        A Python float: number of correct predictions divided by ``targets.size(0)``.
    """
    predicted = torch.max(outputs, 1)[1]  # [1] selects the argmax indices
    num_correct = predicted.eq(targets).sum().item()
    num_samples = targets.shape[0]
    return num_correct / num_samples



# define train and validation function 
def train(model, train_dataloader, 
                IC_middle_first_dataloader,
                IC_first_middle_dataloader,
                IC_first_last_dataloader,
                IC_last_first_dataloader,
                IC_middle_last_dataloader,
                IC_last_middle_dataloader,
                criterion, optimizer, 
                device, num_iter=150000, rec_freq=1):
    """Train ``model`` and periodically report accuracy gaps between paired test sets.

    Makes a single pass over ``train_dataloader``, performing one optimiser step
    per batch and stopping after ``num_iter`` batches (or earlier if the
    dataloader is exhausted first). Every ``rec_freq`` iterations, starting at
    iteration 0, the model is evaluated on the six ``IC_*`` dataloaders and the
    three accuracy differences (first-middle minus middle-first, first-last
    minus last-first, middle-last minus last-middle) are printed.

    Args:
        model: module being optimised.
        train_dataloader: iterable of ``(inputs, targets)`` training batches.
        IC_middle_first_dataloader, IC_first_middle_dataloader,
        IC_first_last_dataloader, IC_last_first_dataloader,
        IC_middle_last_dataloader, IC_last_middle_dataloader: evaluation
            dataloaders passed to ``validate``.
        criterion: loss callable ``criterion(outputs, targets)``.
        optimizer: optimiser holding ``model``'s parameters.
        device: device each batch is moved to.
        num_iter: maximum number of training batches to process.
        rec_freq: evaluate and print every ``rec_freq`` iterations.

    Returns:
        None. Progress is reported via ``print``.
    """
    model.train()
    for i, (inputs, targets) in enumerate(train_dataloader):
        if i >= num_iter:
            break

        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        if i % rec_freq == 0:
            # Only the accuracies are reported; the validation losses are discarded.
            _, IC_middle_first_acc = validate(model, IC_middle_first_dataloader, criterion, device)
            _, IC_first_middle_acc = validate(model, IC_first_middle_dataloader, criterion, device)
            _, IC_first_last_acc = validate(model, IC_first_last_dataloader, criterion, device)
            _, IC_last_first_acc = validate(model, IC_last_first_dataloader, criterion, device)
            _, IC_middle_last_acc = validate(model, IC_middle_last_dataloader, criterion, device)
            _, IC_last_middle_acc = validate(model, IC_last_middle_dataloader, criterion, device)

            print(f"Iter {i}:"
                  f"IC First Middle Acc - Middle First Acc: {IC_first_middle_acc - IC_middle_first_acc:.4f}, "
                  f"IC First Last Acc - IC Last First Acc: {IC_first_last_acc - IC_last_first_acc:.4f}, "
                  f"IC Middle Last Acc - IC Last Middle Acc: {IC_middle_last_acc - IC_last_middle_acc:.4f}, ")

            # validate() leaves the model in eval mode; switch back so the
            # following optimisation steps run with training-mode behaviour
            # (dropout, normalisation statistics, ...).
            model.train()
    

def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on every batch of ``dataloader`` with gradients disabled.

    Args:
        model: module to evaluate. It is switched to eval mode and left in
            eval mode when this function returns.
        dataloader: sized iterable yielding ``(inputs, targets)`` batches.
        criterion: loss callable ``criterion(outputs, targets)`` returning a
            scalar tensor.
        device: device each batch is moved to.

    Returns:
        ``(mean_loss, accuracy)`` where ``mean_loss`` is the mean of the
        per-batch losses and ``accuracy`` is the fraction of correctly
        classified samples over the whole dataloader.

    Raises:
        ZeroDivisionError: if ``dataloader`` contains no batches.
    """
    model.eval()
    running_loss = 0.0
    num_correct = 0
    num_samples = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            running_loss += criterion(outputs, targets).item()
            # Count hits and samples instead of averaging per-batch accuracies,
            # so a smaller final batch is not over-weighted.
            _, predicted = torch.max(outputs, 1)
            num_correct += (predicted == targets).sum().item()
            num_samples += targets.size(0)
    # The loss stays a per-batch average because the criterion's reduction
    # (mean vs. sum) is not known here.
    return running_loss / len(dataloader), num_correct / num_samples

In [4]:
# specify the training dataset arguments: no position bias
# NOTE(review): the meaning of N, B, eps and index is defined by the dataset
# classes in data.py, which is not visible here; the comments below only record
# where each value is used in this notebook.
N = 8  # passed as `N` to TrainDataset and every TestDataset
num_classes = 2048  # passed to the datasets as `K`
dim_features = 64  # passed to the datasets as `D`; also SAN's in_channels and hidden_channels
B = 4  # passed as `B` to TrainDataset and every TestDataset
eps = 0.75  # passed as `eps` to TrainDataset and every TestDataset
pos_bias = False  # only passed to TrainDataset
index = [0]  # only passed to TrainDataset
test_size = 10000  # `num_seqs` of each TestDataset



# specify the training arguments
bs = 128  # batch size of every DataLoader
lr = 1e-3  # AdamW learning rate
wd = 1e-6  # AdamW weight decay




In [5]:
# Build the training set and the six test sets (slow: expect a long wait).
# Keyword arguments common to the training set and every test set.
shared_dataset_kwargs = dict(N=N, K=num_classes, D=dim_features, B=B, eps=eps)

train_dataset = TrainDataset(**shared_dataset_kwargs, pos_bias=pos_bias, index=index)


def _build_test_dataset(test_type):
    """Create a TestDataset of ``test_size`` sequences of the given ``test_type``."""
    return TestDataset(
        num_seqs=test_size,
        test_type=test_type,
        train_dataset=train_dataset,
        **shared_dataset_kwargs,
    )


IC_first_middle_dataset = _build_test_dataset('IC_first_middle')
IC_middle_first_dataset = _build_test_dataset('IC_middle_first')
IC_first_last_dataset = _build_test_dataset('IC_first_last')
IC_last_first_dataset = _build_test_dataset('IC_last_first')
IC_middle_last_dataset = _build_test_dataset('IC_middle_last')
IC_last_middle_dataset = _build_test_dataset('IC_last_middle')

In [6]:
# Training batches are reshuffled; evaluation batches keep dataset order.
train_dataloader = DataLoader(train_dataset, batch_size=bs, shuffle=True)

(
    IC_first_middle_dataloader,
    IC_middle_first_dataloader,
    IC_first_last_dataloader,
    IC_last_first_dataloader,
    IC_middle_last_dataloader,
    IC_last_middle_dataloader,
) = (
    DataLoader(test_dataset, batch_size=bs, shuffle=False)
    for test_dataset in (
        IC_first_middle_dataset,
        IC_middle_first_dataset,
        IC_first_last_dataset,
        IC_last_first_dataset,
        IC_middle_last_dataset,
        IC_last_middle_dataset,
    )
)

In [7]:


# Pick the compute device: the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Cross-entropy loss, used for both training and validation.
criterion = nn.CrossEntropyLoss()


In [None]:


# Attention-mask configuration for the SAN model.
mask_type = "causal"    # Options: "causal", "decay"
gamma = 1   # Options: 1 for causal, 0.8 for decay
num_attn_layers = 2  # Options: 2 or 6

# Build the network on the selected device.
model = SAN(
    in_channels=dim_features,
    hidden_channels=dim_features,
    out_channels=32,
    mask_type=mask_type,
    gamma=gamma,
    num_attn_layers=num_attn_layers,
).to(device)

# AdamW over all model parameters, using the configured learning rate and weight decay.
optimizer = optim.AdamW(
    model.parameters(),
    lr=lr,
    weight_decay=wd,
)

In [54]:
# train! 

# Keyword arguments keep each evaluation loader bound to the matching parameter.
train(
    model=model,
    train_dataloader=train_dataloader,
    IC_middle_first_dataloader=IC_middle_first_dataloader,
    IC_first_middle_dataloader=IC_first_middle_dataloader,
    IC_first_last_dataloader=IC_first_last_dataloader,
    IC_last_first_dataloader=IC_last_first_dataloader,
    IC_middle_last_dataloader=IC_middle_last_dataloader,
    IC_last_middle_dataloader=IC_last_middle_dataloader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    num_iter=1000,
    rec_freq=100,
)

Iter 0:IC First Middle Acc - Middle First Acc: 0.0025, IC First Last Acc - IC Last First Acc: 0.0031, IC Middle Last Acc - IC Last Middle Acc: 0.0046, 
Iter 100:IC First Middle Acc - Middle First Acc: 0.0949, IC First Last Acc - IC Last First Acc: 0.0650, IC Middle Last Acc - IC Last Middle Acc: 0.0037, 
Iter 200:IC First Middle Acc - Middle First Acc: 0.0874, IC First Last Acc - IC Last First Acc: 0.1411, IC Middle Last Acc - IC Last Middle Acc: 0.0160, 
Iter 300:IC First Middle Acc - Middle First Acc: 0.1311, IC First Last Acc - IC Last First Acc: 0.1279, IC Middle Last Acc - IC Last Middle Acc: 0.0111, 
Iter 400:IC First Middle Acc - Middle First Acc: 0.1420, IC First Last Acc - IC Last First Acc: 0.1346, IC Middle Last Acc - IC Last Middle Acc: 0.0148, 
Iter 500:IC First Middle Acc - Middle First Acc: 0.1224, IC First Last Acc - IC Last First Acc: 0.1071, IC Middle Last Acc - IC Last Middle Acc: 0.0122, 
Iter 600:IC First Middle Acc - Middle First Acc: 0.1361, IC First Last Acc - I