In [15]:
import torch
import torchvision

import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim

from torchvision import transforms, datasets

import matplotlib.pyplot as plt
from math import ceil

import numpy as np

from os.path import join

from matplotlib import pyplot as plt

from torchvision.models import inception_v3

from tqdm import tqdm

import pickle

In [2]:
IMG_SIZE = (299, 299)
datadir = 'C:\\associative_represenations_data\\'

# Data

In [3]:
preprocess = transforms.Compose([
    transforms.Resize(299),
    transforms.CenterCrop(299),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [4]:
train = datasets.CocoCaptions(
    join(datadir, 'train2017'), 
    annFile=join(datadir, 'captions_train2017.json'), 
    transform=preprocess)
val = datasets.CocoCaptions(
    join(datadir, 'val2017'),
    annFile=join(datadir, 'captions_val2017.json'),
    transform=preprocess)

loading annotations into memory...
Done (t=0.79s)
creating index...
index created!
loading annotations into memory...
Done (t=0.04s)
creating index...
index created!


In [5]:
trainset = torch.utils.data.DataLoader(train, batch_size=32, shuffle=True)
valset = torch.utils.data.DataLoader(val, batch_size=32, shuffle=True)

In [6]:
train

Dataset CocoCaptions
    Number of datapoints: 118287
    Root location: C:\associative_represenations_data\train2017
    StandardTransform
Transform: Compose(
               Resize(size=299, interpolation=PIL.Image.BILINEAR)
               CenterCrop(size=(299, 299))
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )

# Model

In [7]:
model = inception_v3(pretrained=True)
model.fc = nn.Identity()
model.eval()
model.to('cuda')

Inception3(
  (Conv2d_1a_3x3): BasicConv2d(
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_2a_3x3): BasicConv2d(
    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_2b_3x3): BasicConv2d(
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_3b_1x1): BasicConv2d(
    (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_4a_3x3): BasicConv2d(
    (conv): Conv2d(80, 192, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, t

In [53]:
train_embeds = torch.tensor([])
train_capts = None

for images, captions in tqdm(trainset):
    
    with torch.no_grad():
        embeded_batch = model(images.cuda())
    train_embeds = torch.cat((train_embeds, embeded_batch.cpu()), 0)
    
    captions = np.array(captions).T
    if train_capts is None:
        train_capts = captions
    else:
        train_capts = np.vstack((train_capts, captions))
    
with open('tfeats.pkl', 'wb') as file_embeds:
    pickle.dump(train_embeds.numpy(), file_embeds)
    
with open('tcapts.pkl', 'wb') as file_capts:
    pickle.dump(train_capts, file_capts)

100%|███████████████████████████████████████████████████████████████████████| 3697/3697 [58:21<00:00,  1.06it/s]


In [54]:
val_embeds = torch.tensor([])
val_capts = None

for images, captions in tqdm(valset):
    
    with torch.no_grad():
        embeded_batch = model(images.cuda())
    val_embeds = torch.cat((val_embeds, embeded_batch.cpu()), 0)
    
    captions = np.array(captions).T
    if val_capts is None:
        val_capts = captions
    else:
        val_capts = np.vstack((val_capts, captions))
    
with open('vfeats.pkl', 'wb') as file_embeds:
    pickle.dump(val_embeds.numpy(), file_embeds)
    
with open('vcapts.pkl', 'wb') as file_capts:
    pickle.dump(val_capts, file_capts)

100%|█████████████████████████████████████████████████████████████████████████| 157/157 [01:34<00:00,  1.65it/s]


In [58]:
train_capts.shape

(118287, 5)

In [59]:
 train_capts

array([['A man is in a funny position during a tennis match',
        'A tennis player at the net after his play on the court. ',
        'A man near the net playing tennis with official looking on.',
        'a man by a tennis net getting ready to hit a ball',
        'A man is attempting to return the ball'],
       ['A white van is following an orange and white bus down the road. ',
        'A van following behind a bus in the street. ',
        'A white and orange bus driving down a city street.',
        'A van follows behind a bus on a rural road.',
        'A passenger bus that is driving down a street.'],
       ['A group of children sitting around each other.',
        'four children looking at each other one holding long object',
        'A girl holding a tube talking to another girl.',
        'Group of children sitting on a bench petting a dog.',
        'The children are grouped together waiting their turn.'],
       ...,
       ['the elephants are  all next to each other 