In [1]:
import os
import time

import torch

import dataloaders
import models
from steps.util import *

In [2]:
# Sanity-check setup: a one-sample metadata file feeds BOTH the training and the
# validation loader (presumably to check the pipeline can overfit one example —
# TODO confirm; point `val_loader` at a held-out split for a real evaluation).
data_train = "/teamwork/t40511_asr/c/PlacesAudio400k/PlacesAudio_400k_distro/metadata/train1sample.json"
TARGET_LENGTH = 2048  # audio length used by the dataset AND the audio model's input_length
SEED = 0              # fixes model initialisation and DataLoader shuffling

torch.manual_seed(SEED)


def make_loader(json_path, shuffle):
    """Build a batch-size-1 DataLoader over an ImageCaptionDataset.

    Args:
        json_path: path to the dataset metadata JSON file.
        shuffle: whether the loader reshuffles the data every epoch.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(image, audio, nframes)`` batches.
    """
    dataset = dataloaders.ImageCaptionDataset(
        json_path,
        audio_conf={'target_length': TARGET_LENGTH},
        image_conf={'center_crop': True},
    )
    return torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=shuffle, num_workers=8, pin_memory=True)


train_loader = make_loader(data_train, shuffle=True)
val_loader = make_loader(data_train, shuffle=False)

audio_model = models.ConvX3AudioNet(input_length=TARGET_LENGTH)
image_model = models.VGG16()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_grad_enabled(True)
dot_loss = DotLoss()
# Initialize all of the statistics we want to keep track of
data_time = AverageMeter()
loss_meter = AverageMeter()
progress = []
# float('-inf') instead of -np.inf: `np` is never imported explicitly here.
best_epoch, best_acc = 0, float('-inf')
global_step, epoch = 0, 0

# Wrap each model in DataParallel unless it already is. torch.nn is spelled out
# so this cell does not depend on `nn` leaking in through a star import.
if not isinstance(audio_model, torch.nn.DataParallel):
    audio_model = torch.nn.DataParallel(audio_model)
if not isinstance(image_model, torch.nn.DataParallel):
    image_model = torch.nn.DataParallel(image_model)

In [3]:
# Place both models on the chosen device before the optimizer captures their
# parameters.
audio_model = audio_model.to(device)
image_model = image_model.to(device)

# Gather every parameter that will receive gradients: audio parameters first,
# then image parameters, each in the order the module yields them.
audio_trainables = list(filter(lambda param: param.requires_grad, audio_model.parameters()))
image_trainables = list(filter(lambda param: param.requires_grad, image_model.parameters()))
trainables = [*audio_trainables, *image_trainables]

# Plain SGD with momentum and a small L2 penalty over both networks jointly.
optimizer = torch.optim.SGD(
    trainables,
    lr=1e-4,
    momentum=0.9,
    weight_decay=5e-7,
)

In [4]:
print("current #steps=%s, #epochs=%s" % (global_step, epoch))
print("start training...")

audio_model.train()
image_model.train()

end = time.time()
for i, (image_input, audio_input, nframes) in enumerate(train_loader):
    # Measure data loading time: wall time spent waiting on the DataLoader.
    data_time.update(time.time() - end)
    B = audio_input.size(0)

    audio_input = audio_input.to(device)
    image_input = image_input.to(device)

    optimizer.zero_grad()

    audio_output = audio_model(audio_input)
    image_output = image_model(image_input)
    # NOTE(review): the saved run printed torch.Size([1024]) here with B == 1,
    # i.e. no leading batch axis, and then raised AssertionError. The audio
    # model may be squeezing away the batch dimension — TODO confirm.
    print(audio_output.size())

    # Rescale frame counts to the audio model's output resolution (not yet used
    # by the loss). Out-of-place floor division is used because in-place true
    # division (`div_`) on an integer tensor raises in PyTorch >= 1.7.
    pooling_ratio = round(audio_input.size(-1) / audio_output.size(-1))
    nframes = nframes // pooling_ratio

    loss = dot_loss(image_output, audio_output)

    loss.backward()
    optimizer.step()

    # Record loss (weighted by batch size) and advance the step counter that
    # the header message above reports.
    loss_meter.update(loss.item(), B)
    global_step += 1

    print('Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
          'Loss total {loss_meter.val:.4f} ({loss_meter.avg:.4f})'.format(
            data_time=data_time, loss_meter=loss_meter), flush=True)
    # Restart the timer so the next data_time covers only loader latency.
    end = time.time()

# One full pass over train_loader completed.
epoch += 1


current #steps=0, #epochs=0
start training...
torch.Size([1024])


AssertionError: 

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_time = AverageMeter()
# Repeat the wrapping and device placement from the setup cells. The isinstance
# guards prevent double-wrapping; torch.nn is spelled out so this does not rely
# on `nn` leaking in through a star import.
if not isinstance(audio_model, torch.nn.DataParallel):
    audio_model = torch.nn.DataParallel(audio_model)
if not isinstance(image_model, torch.nn.DataParallel):
    image_model = torch.nn.DataParallel(image_model)
audio_model = audio_model.to(device)
image_model = image_model.to(device)
# switch to evaluate mode
image_model.eval()
audio_model.eval()

end = time.time()
N_examples = len(val_loader.dataset)
I_embeddings = []
A_embeddings = []
frame_counts = []
with torch.no_grad():
    for i, (image_input, audio_input, nframes) in enumerate(val_loader):
        image_input = image_input.to(device)
        audio_input = audio_input.to(device)

        # Compute outputs. Under no_grad nothing is attached to an autograd
        # graph, so a plain CPU copy is enough (no .detach() needed).
        image_output = image_model(image_input).cpu()
        audio_output = audio_model(audio_input).cpu()

        I_embeddings.append(image_output)
        A_embeddings.append(audio_output)

        # Rescale frame counts to the audio output resolution. Out-of-place floor
        # division: in-place `div_` on an integer tensor raises in PyTorch >= 1.7.
        pooling_ratio = round(audio_input.size(-1) / audio_output.size(-1))
        nframes = nframes // pooling_ratio

        frame_counts.append(nframes.cpu())

        batch_time.update(time.time() - end)
        end = time.time()

    # torch.cat fails opaquely on an empty list; report the real cause instead.
    if not I_embeddings:
        raise ValueError("val_loader yielded no batches; cannot compute recalls")
    # NOTE(review): if the audio model drops the batch axis (training printed
    # torch.Size([1024]) for B == 1), these concatenations flatten per-example
    # embeddings into one vector — TODO confirm output shapes before trusting recalls.
    image_output = torch.cat(I_embeddings)
    audio_output = torch.cat(A_embeddings)
    nframes = torch.cat(frame_counts)

    recalls = calc_recalls(image_output, audio_output)