In [None]:
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import random_split
from nscl.datasets.clevr_dataset import build_clevr_dataset, build_clevr_dataloader, CLEVRDataset
from nscl.datasets.clevr_definition import CLEVRDefinition, QuestionTypes
from nscl.models.nscl_module import NSCLModule, ReasoningModule
from nscl.models.executor.program_executor import ProgramExecutor
from tqdm.autonotebook import tqdm
import random
from nscl.models.loss.nscl_loss import SceneParsingLoss, QALoss

In [None]:
# Dataset locations (CLEVR v1.0): image folders plus the scene / question JSON files.
train_img_root = '/data/CLEVR_v1.0/images/train'
train_scene_json = '/data/CLEVR_v1.0/scenes/train/scenes.json'
train_question_json = '/data/CLEVR_v1.0/questions/CLEVR_train_questions.json'

val_img_root = '/data/CLEVR_v1.0/images/val'
val_scene_json = '/data/CLEVR_v1.0/scenes/val/scenes.json'
val_question_json = '/data/CLEVR_v1.0/questions/CLEVR_val_questions.json'

# Data-loader settings consumed by the training and validation cells.
# NOTE(review): these names were used by later cells but never defined in the
# notebook, so a fresh "Restart & Run All" failed with NameError. The values
# below are placeholders -- tune them for the available hardware.
batch_size = 32
num_workers = 4

device = "cuda:0" if torch.cuda.is_available() else "cpu"

model = NSCLModule(CLEVRDefinition.attribute_concept_map)

# Best-effort resume from a previous checkpoint. The checkpoint is mapped to
# CPU first so that weights saved on a GPU machine still load on a CPU-only
# one; the model is moved to `device` right after. Only `Exception` is caught
# (not KeyboardInterrupt/SystemExit) and the cause is reported rather than
# silently discarded.
try:
    model.load_state_dict(torch.load('./nscl.weights', map_location='cpu'))
except Exception as err:
    print(f'Cannot load model: {err!r}')

model = model.to(device)

# Optimise only the parameters that are still trainable.
optimiser = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.01)

# Curriculum stages, applied in order by the training cell.
# Each entry is (number of epochs, max_program_size, max_scene_size).
curriculum_strategies = [
  # (10, 5, 3),
  (10, 10, 3),
  (10, 10, 5),
  (10, 20, 5),
  (10, 30, 7),
]

In [None]:
# Train
# Runs every curriculum stage in order. For each stage a dataset filtered by
# (max_program_size, max_scene_size) is built and the model is trained on it
# for the stage's number of epochs, minimising scene-parsing loss + QA loss.
model.train()
scene_parsing_loss = SceneParsingLoss(reduction='sum')
qa_loss = QALoss(reduction='sum')
# `epoch` is the number of epochs to run for the stage (first tuple element).
for epoch, max_program_size, max_scene_size in curriculum_strategies:
  print(f'Curriculum strategy: {max_program_size}, {max_scene_size}')
  dataset = build_clevr_dataset(train_img_root, train_scene_json, train_question_json, max_program_size=max_program_size, max_scene_size=max_scene_size)
  dataloader = build_clevr_dataloader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True, drop_last=False)

  for i in range(epoch):
    # Accumulates the QA loss only (not the scene-parsing loss) over the epoch.
    epoch_loss = 0
    # Periodic mid-stage checkpoint.
    if i > 0 and i % 5 == 0:
        torch.save(model.state_dict(), './nscl.weights')

    with tqdm(total=len(dataloader), desc='train') as t:
      for images, questions, scenes in dataloader:
        optimiser.zero_grad()
        object_annotations, answers = model(images.to(device), questions, scenes)
        scene_loss = scene_parsing_loss(object_annotations, scenes)
        q_loss = qa_loss(questions, answers)
        total_loss = scene_loss + q_loss
        total_loss.backward()
        optimiser.step()
        epoch_loss += q_loss.item()
        t.set_postfix(batch_loss='scene_parsing_loss:{:05.3f},qa_loss:{:05.3f}'.format(scene_loss.item(), q_loss.item()))
        t.update()
      # set_postfix refreshes the bar itself; no extra update() here, which
      # previously pushed the counter one step past `total`.
      t.set_postfix(epoch_qa_loss='{:05.3f}'.format(epoch_loss))

  # Checkpoint at the end of every stage. The periodic save above fires only
  # at the start of an epoch, so without this the epochs trained after the
  # last periodic save (including the final model) were never persisted.
  torch.save(model.state_dict(), './nscl.weights')

In [None]:
# Validation set
# Build the validation dataset with fixed program / scene size limits, then
# wrap it in a loader that keeps the original order and every sample.
val_dataset = build_clevr_dataset(
    val_img_root,
    val_scene_json,
    val_question_json,
    max_program_size=5,
    max_scene_size=3,
)
val_loader = build_clevr_dataloader(
    val_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=False,
    drop_last=False,
)

In [None]:
# Evaluate on the validation loader: count exact-match answers per question
# and report the total. Roughly 1% of the questions are printed as spot checks.
model.eval()
correct = 0
total = len(val_dataset)

# Gradients are not needed for evaluation; no_grad avoids building an autograd
# graph for every validation batch.
with tqdm(total=len(val_loader), desc='validation') as t, torch.no_grad():
  for images, questions, scenes in val_loader:
    _, results = model(images.to(device), questions, scenes)
    for i, q in enumerate(questions):
      if q.question_type == QuestionTypes.BOOLEAN:
          # NOTE(review): assumes index 0 of the boolean output means 'yes' --
          # confirm against how the model / QALoss encode boolean answers.
          true_answer = 0 if q.answer == 'yes' else 1
          predicted_answer = torch.argmax(results[i]).item()
      elif q.question_type == QuestionTypes.COUNT:
          # Counts are compared after rounding to the nearest integer.
          true_answer = round(q.answer_tensor.item())
          predicted_answer = round(results[i].item())
      else:
          # Remaining question types: compare the argmax index of target and prediction.
          true_answer = torch.argmax(q.answer_tensor).item()
          predicted_answer = torch.argmax(results[i]).item()

      is_answer_correct = (true_answer == predicted_answer)
      if random.random() > 0.99:
        print(f'{q.raw_question}|{true_answer}({q.answer})|{predicted_answer}({results[i].detach().cpu().numpy()})|{is_answer_correct}')
      if is_answer_correct:
        correct += 1

    t.set_postfix(num_correct=f'{correct}')
    t.update()

  print(f'Accuracy : {correct}/{total}')