In [21]:
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import types
import torch
import TextNet
import TeacherNet
import ImageNet
import torch.optim as optim
import time
import sys
#import objgraph
import gc
import copy
import utils

In [22]:
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

In [23]:
def train_model(teacher_model, img_model, txt_model, dataloaders,
                criterion, optimizer, num_epochs=50000, checkpoint_every=100, device=None):
    """Jointly train the image, text and teacher networks with a cross-modal loss.

    Every epoch runs a ``'train'`` phase (forward, backward, optimizer step) and then a
    ``'val'`` phase (forward only). Images are encoded as
    ``teacher_model.forward(img_model.forward(img))`` and text as
    ``teacher_model.forward(txt_model.forward(embeds))``; ``criterion`` compares the two,
    and ``teacher_model.predict`` yields one prediction per sample, counted as correct
    when it equals that sample's index inside the batch.

    Args:
        teacher_model: head shared by both modalities; must provide ``predict``.
        img_model: image encoder.
        txt_model: text encoder.
        dataloaders: dict with ``'train'`` and ``'val'`` loaders. Each batch must be a
            mapping holding ``'image'`` and ``'embeds'`` tensors, and each loader must
            expose ``.dataset`` so per-epoch averages can be computed.
        criterion: callable ``(img_reprets, txt_reprets) -> scalar loss tensor``
            (presumably a per-sample mean — it is re-weighted by the batch size).
        optimizer: optimizer over the parameters to train.
        num_epochs: number of train+val passes.
        checkpoint_every: validation accuracy is recorded, and compared against the best
            so far, only on epochs divisible by this value (default 100, as before).
        device: device batches are moved to; defaults to the notebook-global ``device``.

    Returns:
        ``(teacher_model, val_acc_history)``: the teacher with the best recorded weights
        loaded, and the list of recorded validation accuracies (floats).

    NOTE(review): only the teacher's weights are checkpointed/restored; img_model and
    txt_model keep their final-epoch weights.
    """
    if device is None:
        # Backward compatible: the original body read the notebook-global `device`.
        device = globals()["device"]

    since = time.time()

    val_acc_history = []

    best_model_wts = copy.deepcopy(teacher_model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                teacher_model.train()  # Set model to training mode
                img_model.train()
                txt_model.train()
                print('Train phase')
            else:
                teacher_model.eval()  # Set model to evaluate mode
                img_model.eval()
                txt_model.eval()
                print('Val phase')

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            num_batch = 0
            for sample_batched in dataloaders[phase]:
                num_batch += 1

                # NOTE(review): TensorDataset-based loaders yield tuples, not dicts —
                # confirm the loaders passed here produce {'image', 'embeds'} mappings.
                img = sample_batched['image'].float().to(device)
                embeds = sample_batched['embeds'].float().to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # Only build the autograd graph during the training phase.
                with torch.set_grad_enabled(phase == 'train'):
                    print('%s: [%s] forwarding image and embeddings...' % (str(epoch), str(num_batch)))
                    img_reprets = teacher_model.forward(img_model.forward(img))
                    txt_reprets = teacher_model.forward(txt_model.forward(embeds))

                    loss = criterion(img_reprets, txt_reprets)

                    preds = teacher_model.predict(img_reprets, txt_reprets)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        print('%s: [%s] backward and optimize...' % (str(epoch), str(num_batch)))
                        loss.backward()
                        optimizer.step()

                # Weight by the actual batch size: the last batch may be smaller than the
                # loader's batch_size (the original referenced an undefined `batch_size`).
                running_loss += loss.item() * img.size(0)
                # Sample i is correct when its prediction points back at index i.
                running_corrects += sum(int(pred == i) for i, pred in enumerate(preds))

                # release memory
                del img, embeds, img_reprets, txt_reprets, preds, loss

            dataset_size = len(dataloaders[phase].dataset)
            epoch_loss = running_loss / dataset_size if dataset_size else 0.0
            epoch_acc = running_corrects / dataset_size if dataset_size else 0.0

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # Record validation accuracy / checkpoint the teacher every `checkpoint_every` epochs.
            if phase == 'val' and epoch % checkpoint_every == 0:
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(teacher_model.state_dict())
                val_acc_history.append(epoch_acc)

        gc.collect()
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))

    # load best model weights
    teacher_model.load_state_dict(best_model_wts)
    return teacher_model, val_acc_history

In [24]:
# Directory holding the pre-computed tensors; change it here to point at another checkout.
CACHED_DATA_DIR = "../hci-intermodal-reasoning/cached_data"


def _load_cached(name):
    """Return the tensor cached at ``CACHED_DATA_DIR/<name>``.

    NOTE(review): ``torch.load`` unpickles the file, so only use it on trusted caches.
    """
    return torch.load(CACHED_DATA_DIR + "/" + name)


train_img = _load_cached("train_img")
train_cap = _load_cached("train_cap")
train_mask = _load_cached("train_mask")

val_img = _load_cached("val_img")
val_cap = _load_cached("val_cap")
val_mask = _load_cached("val_mask")

# TensorDataset requires every tensor of a split to share its first dimension;
# fail here with a readable message rather than deep inside the dataset constructor.
for split_name, split_tensors in (("train", (train_img, train_cap, train_mask)),
                                  ("val", (val_img, val_cap, val_mask))):
    assert len({t.size(0) for t in split_tensors}) == 1, "inconsistent {} sizes: {}".format(
        split_name, [tuple(t.size()) for t in split_tensors])

print("Loaded train data", train_img.size(), train_cap.size(), train_mask.size())
print("Loaded val data", val_img.size(), val_cap.size(), val_mask.size())


Loaded train data torch.Size([10000, 3, 224, 224]) torch.Size([10000, 52]) torch.Size([10000, 52])
Loaded val data torch.Size([5000, 3, 224, 224]) torch.Size([5000, 43]) torch.Size([5000, 43])


In [25]:
# Training hyper-parameters.
BATCH_SIZE = 8      # training batch size; the validation loader uses twice this
NB_EPOCHS = 10      # presumably the num_epochs passed to train_model — confirm
DELTA = 0.002       # passed to TeacherNet.RankingLossFunc (presumably its margin)


In [26]:
# Prefer the second GPU, as originally configured, but fall back to the first GPU or the
# CPU so the notebook also runs on machines that have no "cuda:1".
if torch.cuda.device_count() > 1:
    device = "cuda:1"
elif torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# NOTE(review): TensorDataset batches are (img, cap, mask) tuples, whereas train_model
# reads batch['image'] / batch['embeds'] — confirm how these loaders are meant to be consumed.
train_data = TensorDataset(train_img, train_cap, train_mask)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=BATCH_SIZE, num_workers=2)

# Validation iterates in a fixed order, with twice the training batch size.
valid_data = TensorDataset(val_img, val_cap, val_mask)
valid_sampler = SequentialSampler(valid_data)
valid_dataloader = DataLoader(valid_data, sampler=valid_sampler, batch_size=BATCH_SIZE * 2, num_workers=2)

text_net = TextNet.Text_Net(device)
vision_net = ImageNet.Image_Net(device)
teacher_net = TeacherNet.Teacher_Net()
ranking_loss = TeacherNet.RankingLossFunc(DELTA)
teacher_net.to(device)
ranking_loss.to(device)

RankingLossFunc()

In [18]:
model_name = "alexnet"
feature_extract = True
# initialize_model returns the backbone itself (shown later as AlexNet), which is rebound to
# `vision_net`. Keep the Image_Net wrapper under its own name so re-running this cell does
# not call initialize_model on the backbone (which has no such method).
# NOTE(review): assumes initialize_model may be called more than once on the same wrapper.
if hasattr(vision_net, "initialize_model"):
    vision_net_wrapper = vision_net
vision_net = vision_net_wrapper.initialize_model(model_name, feature_extract, use_pretrained=True)

In [19]:
# Preview cell: the main cell below rebuilds these same collections.
print("Params to learn:")
params_to_update_share, params_to_update_txt = [], []
params_to_update_img = vision_net.parameters()


Params to learn:


In [20]:
if __name__ == "__main__":
    # Gather the parameters the optimizer will update from the three networks,
    # printing the name of every parameter with requires_grad set along the way.
    print("Params to learn:")
    params_to_update_share = []
    params_to_update_img = vision_net.parameters()
    params_to_update_txt = []

    for net in (teacher_net, vision_net, text_net):
        print(net)

    def _collect_trainable(module, bucket=None):
        """Print (tab-prefixed) each parameter of `module` whose requires_grad is True and,
        when `bucket` is given, append that parameter to it in iteration order."""
        for param_name, tensor in module.named_parameters():
            if tensor.requires_grad is not True:
                continue
            if bucket is not None:
                bucket.append(tensor)
            print("\t", param_name)

    _collect_trainable(teacher_net, params_to_update_share)

    if feature_extract:
        # Only parameters with requires_grad set are handed to the optimizer.
        params_to_update_img = []
        _collect_trainable(vision_net, params_to_update_img)
    else:
        # The optimizer receives every backbone parameter (the generator above);
        # the trainable ones are still listed.
        _collect_trainable(vision_net)

    _collect_trainable(text_net, params_to_update_txt)

    params_to_update = [*params_to_update_share, *params_to_update_img, *params_to_update_txt]
    optimizer = optim.Adam(params_to_update, lr=0.0001)


Params to learn:
Teacher_Net(
  (linear1): Linear(in_features=9216, out_features=4096, bias=True)
  (linear2): Linear(in_features=4096, out_features=4096, bias=True)
  (linear3): Linear(in_features=4096, out_features=1000, bias=True)
)
AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, s