In [1]:
from itertools import chain

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from nltk.tokenize import word_tokenize
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
import torchvision.models as models

In [6]:
class ImageEncoder(nn.Module):
    """Encode an image batch into a sequence of feature tokens.

    The network is ``models.mobilenet_v2().features`` followed by one
    ``Conv2d`` projection. ``mobilenet_v2`` is called with no arguments here,
    so presumably no pretrained weights are loaded (confirm). ``forward``
    flattens the projected spatial grid into a token axis.

    Args:
        in_channels: channels of the backbone output fed to the projection
            conv. The default 1280 is the value the original notebook
            hard-coded; it must match the backbone's output width
            (NOTE(review): not validated here).
        out_channels: feature width of each output token.
        kernel_size: kernel of the projection conv. No padding is used, so
            the conv shrinks the spatial grid.
    """

    def __init__(self, in_channels=1280, out_channels=256, kernel_size=(3, 3)):
        super(ImageEncoder, self).__init__()

        # These attributes are kept as-is so existing state_dict keys
        # (``mobilenet.*``, ``backbone.*``, ``model.*``) do not change.
        # NOTE(review): only ``.features`` is used; any other parts of
        # ``self.mobilenet`` still count toward ``parameters()``.
        self.mobilenet = models.mobilenet_v2()
        self.backbone = self.mobilenet.features
        self.model = nn.Sequential(
            self.backbone,
            # Honour the constructor arguments instead of hard-coding them.
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
            ),
        )

    def forward(self, input_):
        """Return token features of shape ``(batch, 1, H' * W', C)``.

        ``H'`` and ``W'`` are the spatial size after the projection conv, and
        ``C`` is its ``out_channels``. With the defaults, a
        ``(2, 3, 224, 224)`` input yields ``(2, 1, 25, 256)``.

        Args:
            input_: image batch accepted by the backbone, presumably
                ``(batch, 3, H, W)``. TODO confirm.
        """
        out = self.model(input_)         # (batch, C, H', W')
        # All sizes come from the tensor itself. The old reshape hard-coded a
        # batch of 2 and 256 channels, which mixed samples for other batches.
        out = out.flatten(start_dim=2)   # (batch, C, H' * W')
        out = out.transpose(1, 2)        # (batch, H' * W', C)
        return out.unsqueeze(1)          # (batch, 1, H' * W', C)

In [7]:
# Build the encoder with its default configuration spelled out explicitly.
model = ImageEncoder(in_channels=1280, out_channels=256, kernel_size=(3, 3))

In [8]:
# Count only parameters that receive gradients, so the number matches the
# "Trainable parameters" label. Previously every parameter was counted.
# NOTE(review): this includes any parts of model.mobilenet that forward() never
# uses (only .features feeds the output) — confirm whether they should count.
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Trainable parameters: {num_params:,}")

Trainable parameters: 6,454,248


In [10]:
# Random stand-in image batch (batch=2, 3 channels, 224x224). A locally
# seeded generator makes re-runs reproducible without touching global RNG state.
img = torch.rand(2, 3, 224, 224, generator=torch.Generator().manual_seed(0))

In [11]:
# Shape probe: run without autograd and in eval mode. A training-mode forward
# pass on random data could update normalization running statistics
# (presumably present in the backbone — confirm) and build an unused graph.
# The previous training mode is restored afterwards so later cells are unaffected.
was_training = model.training
model.eval()
with torch.no_grad():
    out = model(img)
model.train(was_training)

In [12]:
# Inspect the encoder output dimensions.
out.size()

torch.Size([2, 1, 25, 256])

In [None]:
# Placeholder features for the fusion experiment: both share the leading
# (batch, 1, tokens) dims and differ only in feature width (512 vs 256).
text, image = torch.rand(2, 1, 16, 512), torch.rand(2, 1, 16, 256)

In [None]:
# Fuse text and image features along the feature axis (dim 3). All other
# dims must match. `dim` is torch's documented keyword; `axis` is a numpy alias.
# NOTE(review): the encoder produced 25 tokens for a 224x224 input in the
# out.shape cell, while these placeholders use 16 — confirm how real
# text/image token counts will be aligned before concatenating.
z = torch.cat((text, image), dim=3)

In [None]:
# Inspect the fused feature dimensions.
z.size()