## Random Task Generation
- Here a new task is generated at random for every training iteration

In [None]:

# Training hyper-parameters
epochs = 10
lr = 1e-3

# Classification loss applied to the model's per-frame predictions
criterion = nn.CrossEntropyLoss()

# Adam only receives the parameters that still require gradients
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=lr,
)

# Causal mask built for MAX_FRAMES and moved onto the training device
mask = generate_causal_mask(MAX_FRAMES).to(device)

# torch.manual_seed(0)
# Generates a random location comparison task
def gen_comp_loc(lm_embedder, img_embedder):
    """Generate one random location-comparison task as model inputs and targets.

    A ``CompareLocTemporal`` task with ``whens=['last0', 'last1']`` is created,
    its object set is rendered into 224-pixel frames, and both the frames and
    the task instruction are embedded.

    Relies on the notebook-level globals ``vit_encoder``, ``tokenizer`` and
    ``lm_encoder``.

    Args:
        lm_embedder: Callable invoked as ``lm_embedder(tokens, lm_encoder)``.
        img_embedder: Callable invoked as ``img_embedder(frames, vit_encoder)``.

    Returns:
        A tuple ``(instruction, frames, targets)``. ``instruction`` and
        ``frames`` are the embedder outputs with a leading dimension added via
        ``unsqueeze(0)``. ``targets`` is a float32 tensor of class indices
        (``true`` -> 0, ``false`` -> 1, ``null`` -> 2), one per answer, also
        with a leading dimension of 1.

    Raises:
        KeyError: If an answer is not one of ``'true'``, ``'false'``, ``'null'``.
    """
    # NOTE: the compared frame pair is currently fixed; randomising the two
    # 'lastN' offsets was sketched earlier but is not enabled.
    task = CompareLocTemporal(whens=['last0', 'last1'])

    frame_info = ig.FrameInfo(task, task.generate_objset())
    compo_info = ig.TaskInfoCompo(task, frame_info)
    objset = compo_info.frame_info.objset

    frames = []
    for rendered, frame in zip(sg.render(objset, 224), compo_info.frame_info):
        # Every frame except the one described as 'ending' gets a fixation cue
        if not any('ending' in description for description in frame.description):
            sg.add_fixation_cue(rendered)
        img = np.array(Image.fromarray(rendered, 'RGB'), dtype=np.float32)
        # Move the channel axis first: (H, W, C) -> (C, H, W)
        frames.append(img.transpose(2, 0, 1))

    # Stack into a single ndarray first: torch.tensor() on a Python list of
    # ndarrays converts element by element and is extremely slow.
    frames = torch.from_numpy(np.stack(frames))
    frames = img_embedder(frames, vit_encoder).unsqueeze(0)

    _, compo_example, _ = compo_info.get_examples()

    instruction = tokenizer(
        compo_example['instruction'],
        padding=True,
        truncation=True,
        return_tensors='pt',
    )
    instruction = lm_embedder(instruction, lm_encoder).unsqueeze(0)

    # Map the textual answers onto class indices
    action_map = {'true': 0, 'false': 1, 'null': 2}
    target_actions = [action_map[action] for action in compo_example['answers']]

    # float32 is kept for backward compatibility; callers cast with .long()
    targets = torch.tensor(target_actions, dtype=torch.float32).unsqueeze(0)
    return instruction, frames, targets

In [None]:
# Training and validation loop using random task generation

# Number of randomly generated tasks per training epoch.
# NOTE(review): with n_tasks = 0 the training loop body never runs; set this
# to a positive value to actually train.
n_tasks = 0

# Validate and record statistics every `val_interval` epochs
val_interval = 2

# History of average loss / accuracy, one entry per validation round
all_loss = {'train_loss': [], 'val_loss': []}
all_acc = {'train_acc': [], 'val_acc': []}

print("starting")
for epoch in range(epochs):
    print(f"epoch={epoch}")

    # Epoch stat trackers
    epoch_loss = 0
    epoch_correct = 0
    epoch_count = 0

    # NOTE(review): model.train() / model.eval() are never toggled here;
    # confirm whether the model has dropout/batch-norm layers that need it.
    for _ in range(n_tasks):
        # Inputs and targets for one freshly generated task
        instruction, frames, targets = gen_comp_loc(lm_embedder, img_embedder)

        # Frame padding: keep only the positions whose mask entry is False
        padding_mask = generate_pad_mask(batch=frames)
        pad_indexes = np.argwhere(np.array(padding_mask) == False)[:, 1]

        # Clear gradients left over from the previous step; without this they
        # accumulate across iterations.
        optimizer.zero_grad()

        # Get predictions
        predictions = model(instruction, frames, mask, padding_mask)
        predictions = predictions[:, pad_indexes]

        # Class indices must be integer typed and on the predictions' device
        targets = targets.long().to(predictions.device)

        # CrossEntropyLoss expects (batch_size, n_classes, seq_len)
        loss = criterion(predictions.permute(0, 2, 1), targets)

        # Track stats; count every predicted position, not just the batch dim
        correct = predictions.argmax(dim=-1) == targets
        epoch_correct += correct.sum().item()
        epoch_count += correct.numel()
        epoch_loss += loss.item()

        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

    # Validate on the validation set every `val_interval` epochs
    if (epoch + 1) % val_interval == 0:
        val_epoch_loss = 0
        val_epoch_correct = 0
        val_epoch_count = 0

        # Turn off gradient calcs
        with torch.no_grad():
            for batch in val_dataloader:
                # Inputs and targets
                instruction = batch['instruction']
                frames = batch['frames']
                targets = batch['actions']

                # Frame padding
                padding_mask = generate_pad_mask(batch=frames)
                pad_indexes = np.argwhere(np.array(padding_mask) == False)[:, 1]

                # Get predictions
                predictions = model(instruction, frames, mask, padding_mask)
                predictions = predictions[:, pad_indexes]

                targets = targets.long().to(predictions.device)

                # Get loss
                loss = criterion(predictions.permute(0, 2, 1), targets)

                correct = predictions.argmax(dim=-1) == targets
                val_epoch_correct += correct.sum().item()
                val_epoch_count += correct.numel()
                val_epoch_loss += loss.item()

        # Averages; NaN is recorded when there was nothing to average over
        nan = float('nan')
        n_val_batches = len(val_dataloader)

        # Training loss is averaged over the tasks generated this epoch
        avg_train_loss = epoch_loss / n_tasks if n_tasks else nan
        avg_val_loss = val_epoch_loss / n_val_batches if n_val_batches else nan
        train_acc = epoch_correct / epoch_count if epoch_count else nan
        val_acc = val_epoch_correct / val_epoch_count if val_epoch_count else nan

        all_loss['val_loss'].append(avg_val_loss)
        all_acc['val_acc'].append(val_acc)

        all_acc['train_acc'].append(train_acc)
        all_loss['train_loss'].append(avg_train_loss)