In [1]:
# Class Dataset
from torch.utils.data import Dataset
from label_to_index import map_classes

class ProductDataset(Dataset):
    def __init__(self, products, max_text_length, img_transform=None, txt_transform=None):
        self.products = products
        self.img_transform = img_transform
        self.txt_transform = txt_transform
        self.class_to_idx = map_classes()
        self.max_text_length = max_text_length
    
    def __len__(self):
        return len(self.products)
    
    def __getitem__(self, idx):
        product = self.products[idx]
        img, text = product['image'], product['description']
        img = img.convert('RGB')
        text = self.pad(text)
        if self.img_transform:
            img = self.img_transform(img)
        if self.txt_transform:
            text = self.txt_transfrom(text)
        return (img, text), self.class_to_idx[product['label']]
    
    def pad(self, desc):
        while len(desc) != self.max_text_length:
            desc = desc + ' '
        return desc
        
            
        
        

In [2]:
# From - https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

from torchvision import transforms

In [3]:
# Load Augmented dataset
import pickle

from random import shuffle

augmented_data = pickle.load(open('augmented_data.p', 'rb'))
shuffle(augmented_data)

# find text with max len
max_len = float('-inf')
for product in augmented_data:
    max_len = max(len(product['description']), max_len)
    



In [4]:
print(max_len)

2379


In [5]:
len(augmented_data)

1100

In [6]:
len(augmented_data[800:])

300

In [7]:


trfm = transforms.Compose([
                        transforms.Resize((64, 64)),
                        transforms.ToTensor(),
                        transforms.Normalize(mean=[0.5], std=[0.224])
        ])


train_dataset = augmented_data[:800]
val_dataset = augmented_data[800:]

train_dataset = ProductDataset(train_dataset, max_len, img_transform=trfm)
val_dataset = ProductDataset(val_dataset, max_len, img_transform=trfm)

In [8]:
for product in train_dataset:
    print(product)
    break

((tensor([[[-2.2321, -2.2321, -2.2321,  ..., -2.2321, -2.2321, -2.2321],
         [-2.2321, -2.2321, -2.2321,  ..., -2.2321, -2.2321, -2.2321],
         [-2.2321, -2.2321, -2.2321,  ..., -2.2321, -2.2321, -2.2321],
         ...,
         [-2.2321, -2.2321, -2.2321,  ..., -2.2321, -2.2321, -2.2321],
         [-2.2321, -2.2321, -2.2321,  ..., -2.2321, -2.2321, -2.2321],
         [-2.2321, -2.2321, -2.2321,  ..., -2.2321, -2.2321, -2.2321]],

        [[-2.2321, -2.2321, -2.2321,  ..., -2.2321, -2.2321, -2.2321],
         [-2.2321, -2.2321, -2.2321,  ..., -2.2321, -2.2321, -2.2321],
         [-2.2321, -2.2321, -2.2321,  ..., -2.2321, -2.2321, -2.2321],
         ...,
         [-2.2321, -2.2321, -2.2321,  ..., -2.2321, -2.2321, -2.2321],
         [-2.2321, -2.2321, -2.2321,  ..., -2.2321, -2.2321, -2.2321],
         [-2.2321, -2.2321, -2.2321,  ..., -2.2321, -2.2321, -2.2321]],

        [[-2.2321, -2.2321, -2.2321,  ..., -2.2321, -2.2321, -2.2321],
         [-2.2321, -2.2321, -2.2321,  ..., 

In [9]:
with open('dataset.p', 'wb') as f:
    pickle.dump((train_dataset, val_dataset), f)

In [10]:
# Define the CNN Model 
# Source = https://github.com/meghanabhange/FashionMNIST-3-Layer-CNN/blob/master/Fashion.py
import torch
import torch.nn as nn

class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        
        #Conv 1
        #Convolutio
        self.cnn1= nn.Conv2d(in_channels=3, out_channels= 16, kernel_size=5, stride=1, padding=2)
        #Activation
        self.relu1 = nn.ReLU()
        self.conv1_bn = nn.BatchNorm2d(16)
        #MaxPool1
        self.MaxPool1 = nn.MaxPool2d(kernel_size=2)
        
        #Conv 2
        self.cnn2 = nn.Conv2d(in_channels= 16, out_channels= 32, kernel_size=5,stride=1, padding=2)
        self.relu2 = nn.ReLU()
        self.conv2_bn = nn.BatchNorm2d(32)
        
        self.MaxPool2 = nn.MaxPool2d(kernel_size=2)
        
        self.cnn3 = nn.Conv2d(in_channels= 32, out_channels= 64, kernel_size=5,stride=1, padding=2)
        self.relu3 = nn.ReLU()
        self.conv3_bn = nn.BatchNorm2d(64)
        
        self.MaxPool3 = nn.MaxPool2d(kernel_size=2)
        
        self.fc1 = nn.Linear(4096, 10)
    
    
    def forward(self, x):
        
        out = self.cnn1(x)
        out= self.relu1(out)
        our = self.conv1_bn(out)
        
        out= self.MaxPool1(out)
        
        out= self.cnn2(out)
        out = self.relu2(out)
        out = self.conv2_bn(out)
        
        out = self.MaxPool2(out)
        
        out= self.cnn3(out)
        out = self.relu3(out)
        out = self.conv3_bn(out)
        
        out = self.MaxPool3(out)
        
        
        out = out.view(out.size(0), -1)
        out = self.fc1(out)
        
        return out

In [11]:
img_model = CNNModel()
# img_model.load_state_dict(torch.load('cnn.pkl'))

In [12]:
params = list(img_model.parameters())
len(params)

14

In [13]:
for i, p in enumerate(params):
    print(i, p.shape, p.requires_grad)

0 torch.Size([16, 3, 5, 5]) True
1 torch.Size([16]) True
2 torch.Size([16]) True
3 torch.Size([16]) True
4 torch.Size([32, 16, 5, 5]) True
5 torch.Size([32]) True
6 torch.Size([32]) True
7 torch.Size([32]) True
8 torch.Size([64, 32, 5, 5]) True
9 torch.Size([64]) True
10 torch.Size([64]) True
11 torch.Size([64]) True
12 torch.Size([10, 4096]) True
13 torch.Size([10]) True


In [14]:
# for i, p in enumerate(img_model.parameters()):
#     if i != 8 and i != 13:
#         p.requires_grad = False

In [15]:
# for i, p in enumerate(params):
#     print(i, p.shape, p.requires_grad)

In [16]:
# Define Text Model
# Source https://github.com/bentrevett/pytorch-sentiment-analysis/blob/master/3%20-%20Faster%20Sentiment%20Analysis.ipynb

from torchtext import data, datasets

def generate_bigrams(x):
    n_grams = set(zip(*[x[i:] for i in range(2)]))
    for n_gram in n_grams:
        x.append(' '.join(n_gram))
    return x

text = data.Field(tokenize='spacy', preprocessing=generate_bigrams)

In [17]:
train_text_data = [t[0][1] for t in train_dataset]
len(train_text_data)

800

In [18]:
# Total number of words in train_text_data (approx.)
from collections import Counter
words = Counter()
for d in train_text_data:
    for word in d.split():
        words[word] += 1

In [19]:
len(words)

2707

In [20]:
words.most_common()

[('the', 1307),
 ("Women's", 774),
 ('and', 773),
 ('of', 666),
 ('at', 550),
 ('see', 470),
 ('Shop', 462),
 ('selection', 456),
 ('entire', 455),
 ('J.Crew', 451),
 ('Free', 445),
 ('Shipping', 445),
 ('Available.', 445),
 ('a', 445),
 ('In', 283),
 ('with', 270),
 ('\\-', 223),
 ('to', 182),
 ('for', 168),
 ('in', 166),
 ('our', 166),
 ('from', 114),
 ('is', 103),
 ('size', 84),
 ('Import.', 83),
 ('on', 81),
 ('+', 81),
 ('Top', 79),
 ('an', 78),
 ('Swimwear.', 77),
 ('Cotton', 77),
 ('100%', 72),
 ('Tall', 71),
 ('this', 70),
 ("Girls'", 66),
 ('Dress', 66),
 ('your', 63),
 ('it', 62),
 ('Bags.', 61),
 ('Dresses.', 61),
 ('skirt', 55),
 ('by', 54),
 ('de', 54),
 ('flattering', 51),
 ('waist', 49),
 ('The', 49),
 ('Imported', 48),
 ('stretch', 47),
 ('Sits', 47),
 ('are', 46),
 ("Men's", 46),
 ('leg', 45),
 ('perfect', 43),
 ('Leather', 43),
 ('fit.', 43),
 ('back', 42),
 ('matching', 41),
 ('cm', 41),
 ('Shirts.', 40),
 ('ruffled', 40),
 ('Pant', 40),
 ('A', 40),
 ('**Content', 40

In [21]:
text.build_vocab(train_text_data, vectors='glove.6B.100d', unk_init=torch.Tensor.normal_)

In [22]:
PAD_IDX = text.vocab.stoi[text.pad_token]
MAX_LEN = max_len

def convert_to_tensor(d):
    t = []
    for word in d.split():
        t.append(text.vocab.stoi[word])
    import pdb
    pdb.set_trace()
    while len(t) != MAX_LEN:
        t.append(PAD_IDX)
        print('infi loop')
    return torch.Tensor(t)

In [23]:
class FastText(nn.Module):
    def __init__(self, vocab_size, embedding_dim, output_dim, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)
        self.fc = nn.Linear(embedding_dim, output_dim)
        
    def forward(self, text):
        
        #text = [sent len, batch size]
        
        embedded = self.embedding(text)
                
        #embedded = [sent len, batch size, emb dim]
        
        embedded = embedded.permute(1, 0, 2)
        
        #embedded = [batch size, sent len, emb dim]
        
        pooled = F.avg_pool2d(embedded, (embedded.shape[1], 1)).squeeze(1) 
        
        #pooled = [batch size, embedding_dim]
                
        return self.fc(pooled)

In [24]:
INPUT_DIM = len(text.vocab)
EMBEDDING_DIM = 100
OUTPUT_DIM = 4
PAD_IDX = text.vocab.stoi[text.pad_token]

text_model = FastText(INPUT_DIM, EMBEDDING_DIM, OUTPUT_DIM, PAD_IDX)

In [None]:
from torch.utils.data import DataLoader

batch_size = 20

img_train_dataset = []
txt_train_dataset = []

for t in train_dataset:
    img_train_dataset.append((t[0][0], t[1]))
    txt_train_dataset.append((convert_to_tensor(t[0][1]), t[1]))

> <ipython-input-22-702b772bfbe5>(10)convert_to_tensor()
-> while len(t) != MAX_LEN:
(Pdb) t
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
(Pdb) d
*** Newest frame
(Pdb) (d)
"Free Shipping Available.\nShop the Women's Straw Market Tote at J.Crew and see the entire selection of Women's Bags.                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         

In [None]:
train_img_loader = DataLoader(img_train_dataset, batch_size)
train_txt_loader = DataLoader(txt_train_dataset, batch_size)

In [None]:
pretrained_embeddings = text.vocab.vectors

text_model.embedding.weight.data.copy_(pretrained_embeddings)

In [None]:
UNK_IDX = text.vocab.stoi[text.unk_token]

text_model.embedding.weight.data[UNK_IDX] = torch.zeros(EMBEDDING_DIM)
text_model.embedding.weight.data[PAD_IDX] = torch.zeros(EMBEDDING_DIM)

In [None]:
# Overall Model
IMG_MODEL_OUTPUTS = 10
TXT_MODEL_OUTPUTS = 4
NUM_CLASSES = 11

class OverallModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.img_model = img_model
        self.text_model = text_model
        self.fc = nn.Linear(IMG_MODEL_OUTPUTS + TXT_MODEL_OUTPUTS, NUM_CLASSES)
    
    def forward(self, obj):
        img, text = obj
        img_out = self.img_model(img)
        txt_out = self.text_model(text)
        return self.fc(torch.cat((img_out, txt_out),0))
        

In [None]:
model = OverallModel()

In [None]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

In [None]:
import torch.optim as optim

optimizer = optim.Adam(model.parameters())

In [None]:
criterion = nn.BCEWithLogitsLoss()


In [None]:
def accuracy(preds, labels):
    pass

def train(model, txt_loader, img_loader, optimizer, criterion):
    epoch_loss = 0
    epoch_acc = 0
    model.train()
    img_iter = iter(img_loader)
    txt_iter = iter(txt_loader)
    while True:
        try:
            img_batch, _ = next(img_iter)
            txt_batch, labels = next(txt_iter)
            data = (img_batch, txt_batch)
        except StopIteration:
            break
        optimizer.zero_grad()
        predictions = model(data).squeeze(1)
        loss = criterion(predictions, label)
        acc = accuracy(predictions, label)
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
        epoch_acc += acc.item()
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

In [None]:
for epoch in range(4):
    tl, ta = train(model, train_txt_loader, train_img_loader, optimizer, criterion)

In [None]:
txt_batch