In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
from torch.optim import Adam
from tqdm import tqdm
import matplotlib.pyplot as plt
from model import MambaFull, generate_data

# Define model parameters and hyperparameters
class DotDict(dict):
    def __init__(self, **kwds):
        self.update(kwds)
        self.__dict__ = self

args=DotDict()
args.bsz=100
args.d_model = 64
args.coord_dim = 2
args.nb_layers = 2
args.mlp_cls = nn.Identity #nn.Linear
args.city_count = 50
args.sequence_length = args.city_count + 1
args.deterministic = False #used for sampling from the model

In [2]:
seed = 402
torch.manual_seed(seed)
np.random.seed(seed)

model = MambaFull(args.d_model, args.city_count, args.nb_layers, args.coord_dim, args.mlp_cls)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

In [16]:
# Define a loss function
loss_fn = nn.CrossEntropyLoss()

# Define an optimizer
optimizer = Adam(model.parameters(), lr=1e-4)

# List to store loss values
loss_values = []
best_loss = float('inf')

# Training loop
for epoch in tqdm(range(1)):
    # Mask is used to prevent the model from choosing the same city twice
    mask = torch.ones(args.bsz, args.city_count).to(device)
    #Inputs will have size (bsz, seq_len, coord_dim)
    inputs = generate_data(device, args.bsz, args.city_count, args.coord_dim)
    # list that will contain Long tensors of shape (bsz,) that gives the idx of the cities chosen at time t
    tours = []
    # list that will contain Float tensors of shape (bsz,) that gives the log probs of the choices made at time t
    sumLogProbOfActions = []
    #Construct tour recursively
    for i in range(args.city_count):
        #print(i)
        outputs = model(inputs)[:,-1,:]
        #print(outputs[0])
        outputs = outputs.masked_fill_(mask == 0, -float('inf'))
        #print(outputs[0])
        outputs = nn.Softmax(dim=1)(outputs)
        #print(outputs[0])
        #if args.deterministic:
        next_city = torch.argmax(outputs, dim=1)
        #print(next_city.shape)
        #else:
        #    next_city = Categorical(outputs).sample()
        #print(next_city[0])
        tours.append(next_city)
        sumLogProbOfActions.append(torch.log(outputs[torch.arange(args.bsz), next_city]) )
        mask[torch.arange(args.bsz), next_city] = 0
        inputs = torch.cat((inputs, inputs[torch.arange(args.bsz), next_city, :].unsqueeze(1)), dim=1)
        args.sequence_length += 1
    print(len(sumLogProbOfActions))
    tours = torch.stack(tours, dim=1)
    sumLogProbOfActions = torch.stack(sumLogProbOfActions, dim=1).sum(dim=1)

    
    '''loss = loss_fn(outputs.reshape(-1,7), targets.reshape(-1))

    # Backward pass: compute the gradients of the loss with respect to the model's parameters
    loss.backward()

    # Update the model's parameters
    optimizer.step()

    # Zero the gradients
    optimizer.zero_grad()

    # Save checkpoint every 10,000 epochs
    if loss.item() < best_loss:
        best_loss = loss.item()
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss.item()
        }
        torch.save(checkpoint, 'best_checkpoint.pt')'''
    
# Plot loss values
'''plt.plot(loss_values)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.show()'''

100%|██████████| 1/1 [00:00<00:00,  3.52it/s]

50
tensor([ 4, 46,  5, 44, 14, 10, 24, 21, 16, 19, 37, 49, 48, 43,  1,  8,  3, 15,
        47, 36, 26, 30, 40,  6, 34, 32, 41, 31, 13, 35, 27, 45,  2, 23, 17,  7,
        38,  9, 12,  0, 33, 25, 28, 22, 20, 11, 29, 42, 18, 39],
       device='cuda:0')





"plt.plot(loss_values)\nplt.xlabel('Epoch')\nplt.ylabel('Loss')\nplt.show()"

In [4]:
checkpoint = torch.load('mamba/best_checkpoint.pt')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

RuntimeError: Error(s) in loading state_dict for MambaFull:
	Missing key(s) in state_dict: "norm_f.weight", "norm_f.bias", "embedding.bias", "layers.0.norm.bias", "layers.1.norm.bias". 
	Unexpected key(s) in state_dict: "final_norm.weight". 
	size mismatch for embedding.weight: copying a param with shape torch.Size([7, 64]) from checkpoint, the shape in current model is torch.Size([64, 2]).
	size mismatch for layers.0.mixer.A_log: copying a param with shape torch.Size([128, 16]) from checkpoint, the shape in current model is torch.Size([128, 64]).
	size mismatch for layers.0.mixer.x_proj.weight: copying a param with shape torch.Size([36, 128]) from checkpoint, the shape in current model is torch.Size([132, 128]).
	size mismatch for layers.1.mixer.A_log: copying a param with shape torch.Size([128, 16]) from checkpoint, the shape in current model is torch.Size([128, 64]).
	size mismatch for layers.1.mixer.x_proj.weight: copying a param with shape torch.Size([36, 128]) from checkpoint, the shape in current model is torch.Size([132, 128]).
	size mismatch for output_head.weight: copying a param with shape torch.Size([7, 64]) from checkpoint, the shape in current model is torch.Size([50, 64]).

In [None]:
print(sum(p.numel() for p in model.parameters()))
inputs, targets = generate_data(batch_size=10000)
inputs = torch.from_numpy(inputs).long().to(device)
targets = torch.from_numpy(targets).long().to(device)
outputs = model(inputs)
outputs = nn.Softmax(dim=2)(outputs)
print(inputs[0])
print(outputs[0].argmax(dim=1))
print('accuracy:', (outputs.argmax(dim=2) == targets).float().mean())
print(targets[0])

66368
tensor([0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 2, 0, 3, 0, 5, 0, 0, 0, 1, 6, 6, 6, 6,
        6], device='cuda:0')
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 2, 3, 5,
        1], device='cuda:0')
accuracy: tensor(1.0000, device='cuda:0')
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 2, 3, 5,
        1], device='cuda:0')
