<a href="https://colab.research.google.com/github/FoxHound0x00/SharadaHTR/blob/test/SharadaHTR_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [22]:
!pip3 install imageio matplotlib opencv-python pandas pillow python-Levenshtein scikit-image scipy torch torchaudio torchvision torchinfo

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
!cp -r /content/drive/MyDrive/Sharada_files/output.tar.gz /content/
!tar -xf /content/output.tar.gz

In [5]:
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

Using device: cuda


In [6]:
import os
import shutil
import re

DIR = "output/"
DST_DIR = "temp_annot/"

os.makedirs(DST_DIR,exist_ok=True)
files = os.listdir(DIR)


for f in files:
    if f.endswith('.json'):
        json_file = os.path.join(DIR,f)
        img_file = os.path.join(DIR,os.path.splitext(f)[0]+'.png')
        if os.path.isfile(img_file):
            shutil.copy(img_file,DST_DIR)
            shutil.copy(json_file,DST_DIR)

In [7]:
import json
import random
import string
import numpy as np
from PIL import Image, ImageDraw, ImageOps

os.makedirs("temp_annot/",exist_ok=True)
os.makedirs("extracted_dir/",exist_ok=True)

src_dir =  "temp_annot/"
dest_dir = "extracted_dir/"

for file in os.listdir(src_dir):
    if not file.endswith('.json'):

        img_path = os.path.join(src_dir,file)
        json_path = os.path.join(src_dir,os.path.splitext(file)[0]+'.json')

        if os.path.exists(json_path):
            with open(json_path, 'r', encoding='utf-8') as f:
                annotation = json.load(f)

            image_height = annotation['imageHeight']
            image_width = annotation['imageWidth']
            image = Image.open(img_path)


            for shape in annotation['shapes']:
                shape_type = shape['shape_type']
                group_id = shape['group_id']
                label = shape['label']
                coordinates = shape['points']
                cropped_image = Image.new('RGBA', (image_width, image_height), (0, 0, 0, 0))
                draw = ImageDraw.Draw(cropped_image)
                random_string = ''.join(random.choices(string.ascii_letters + string.digits, k=15))
                int_coordinates = [(int(point[0]), int(point[1])) for point in coordinates]
                draw.polygon(int_coordinates, fill=(255, 255, 255, 255))


                if shape_type == 'polygon':
                    mask = ImageOps.invert(cropped_image.convert('L'))

                    masked_image = Image.new("1", image.size)
                    masked_image.paste(image, (0, 0), mask=cropped_image)
                    bbox = masked_image.getbbox()

                    if bbox:
                        cropped_image = masked_image.crop(bbox)
                        rectangular_image = Image.new("1", (cropped_image.width, cropped_image.height), (255, 255, 255))
                        rectangular_image.paste(cropped_image, (0, 0), cropped_image)
                        new_ImageName = os.path.join(dest_dir, f"{random_string}.jpg")
                        new_LabelName = os.path.join(dest_dir, f"{random_string}.txt")
                        rectangular_image.save(new_ImageName, format='JPEG', quality=100)
                        open(new_LabelName, "w", encoding="utf-8").write(label)

                if shape_type == 'rectangle':
                    x_coordinates = [point[0] for point in int_coordinates]
                    y_coordinates = [point[1] for point in int_coordinates]
                    left = min(x_coordinates)
                    top = min(y_coordinates)
                    right = max(x_coordinates)
                    bottom = max(y_coordinates)
                    cropped_image = image.crop((left, top, right, bottom))
                    new_ImageName = os.path.join(dest_dir, f"{random_string}.jpg")
                    new_LabelName = os.path.join(dest_dir, f"{random_string}.txt")
                    cropped_image.save(new_ImageName, format='JPEG', quality=100)
                    open(new_LabelName, "w", encoding="utf-8").write(label)

# 1 (1-bit pixels, black and white, stored with one pixel per byte)
# L (8-bit pixels, grayscale)
# RGB (3x8-bit pixels, true color)
# RGBA (4x8-bit pixels, true color with transparency mask)

In [1]:
def levenshtein_distance(str1, str2):
    len_str1 = len(str1) + 1
    len_str2 = len(str2) + 1

    # Initialize a matrix to store distances
    matrix = [[0] * len_str2 for _ in range(len_str1)]

    # Initialize the first row and column
    for i in range(len_str1):
        matrix[i][0] = i
    for j in range(len_str2):
        matrix[0][j] = j

    # Calculate distances
    for i in range(1, len_str1):
        for j in range(1, len_str2):
            cost = 0 if str1[i - 1] == str2[j - 1] else 1
            matrix[i][j] = min(
                matrix[i - 1][j] + 1,  # Deletion
                matrix[i][j - 1] + 1,  # Insertion
                matrix[i - 1][j - 1] + cost,  # Substitution
            )

    return matrix[len_str1 - 1][len_str2 - 1]

## Dataset

In [2]:
import pandas as pd
import numpy as np
import os
from torch.utils.data import Dataset
from PIL import Image
class SharadaDataset(Dataset):
    """Scripture dataset Class."""

    def __init__(self, txt_dir, img_dir, transform=None, char_dict=None):
        """
        Args:
            txt_dir (string): Path to the txt file with labels.
            img_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.txt_dir = txt_dir
        self.img_dir = img_dir
        self.transform = transform
        self.max_len = 0
        self.char_list = " -ँंःअआइईउऊऋऌऍऎएऐऑऒओऔकखगघङचछजझञटठडढणतथदधनऩपफबभमयरऱलळऴवशषसह़ऽािीुूृॄॅेैॉॊोौ्ॐ॒॑॓॔क़ख़ग़ज़ड़ढ़फ़य़ॠॢ।॥०१२३४५६७८९॰ॱॲॻॼॽॾ≈–|"
        if self.char_list is not None:
            chars = sorted(list(set(self.char_list)))
            self.char_dict = {c:i for i,c in enumerate(chars,1)}

        txt_files = os.listdir(self.txt_dir)
        self.txt_paths = [txt_file for txt_file in txt_files if txt_file.endswith('.txt')]
        img_files = os.listdir(self.img_dir)
        self.img_paths = [img_file for img_file in img_files if img_file.endswith('.jpg')]


    def __len__(self):
        return len(self.txt_paths)

    def __getitem__(self, idx):
        img_name = self.img_paths[idx]
        img_filepath = os.path.join(self.img_dir,img_name)
        try:
            image = Image.open(img_filepath)

        except OSError:
            image = np.random.randint(0, 255, size=(50, 100), dtype=np.uint8)

        txt_name = self.txt_paths[idx]
        txt_filepath = os.path.join(self.txt_dir,txt_name)
        try:
            with open(txt_filepath,'r') as file:
                label = file.read()

        except OSError:

            label = ""
        if len(label) > self.max_len:
            self.max_len = len(label)

        sample = {'image': image, 'label': label}
        # print(sample)
        if self.transform:
            sample = self.transform(sample)
        return sample

## Dataloader

In [3]:
# Dataloader Class
import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np
# from SharadaDS import SharadaDataset

class SharadaDataLoader(object):

    def __init__(self, ds, batch_size=(16, 16), validation_split=0.2,
                 shuffle=True, seed=42, device='cpu'):
        assert isinstance(ds, SharadaDataset)
        assert isinstance(batch_size, tuple)
        assert isinstance(validation_split, float)
        assert isinstance(shuffle, bool)
        assert isinstance(seed, int)
        assert isinstance(device, str)

        self.ds = ds
        self.batch_size = batch_size
        self.validation_split = validation_split
        self.shuffle = shuffle
        self.seed = seed
        self.device = device

    def  __call__(self):

        dataset_size = len(self.ds)
        indices = list(range(dataset_size))
        split = int(np.floor(self.validation_split * dataset_size))

        if self.shuffle:
            np.random.seed(self.seed)
            np.random.shuffle(indices)
        train_indices, val_indices = indices[split:], indices[:split]

        # Creating PT data samplers and loaders:
        train_sampler = SubsetRandomSampler(train_indices)
        valid_sampler = SubsetRandomSampler(val_indices)

        # Dataloader
        train_loader = DataLoader(self.ds, batch_size=self.batch_size[0],
                                  sampler=train_sampler, collate_fn=self.collate_fn)
        validation_loader = DataLoader(self.ds, batch_size=self.batch_size[1],
                                       sampler=valid_sampler, collate_fn=self.collate_fn)

        return train_loader, validation_loader



    def collate_fn(self, batch):
        """Creates mini-batch tensors from the list of tuples (image, label).

        We should build custom collate_fn rather than using default collate_fn,
        because merging label tensor creates jagged array.
        Args:
            data: list of tuple (image, caption).
                - image: torch tensor of shape (1, 128, 32).
                - label: torch tensor of shape (?); variable length.
        Returns:
            images: torch tensor of shape (batch_size, chan_in, height, width).
            targets: torch tensor of shape (sum(target_lengths)).
            lengths: torch tensor; length of each target label.
        """

        # Sort a data list by caption length (descending order).
        #sample.sort(key=lambda x: len(x[1]), reverse=True)
        images, labels = [b.get('image') for b in batch], [b.get('label') for b in batch]

        # Merge images (from tuple of 3D tensor to 4D tensor).
        images = torch.stack(images, 0)

        # Merge captions (from tuple of 1D tensor to 2D tensor).
        lengths = [len(label) for label in labels]
        targets = torch.zeros(sum(lengths)).long()
        lengths = torch.tensor(lengths)
        for j, label in enumerate(labels):
            start = sum(lengths[:j])
            end = lengths[j]
            targets[start:start+end] = torch.tensor([self.ds.char_dict.get(letter) for letter in label]).long()

        if self.device == 'cpu':
            dev = torch.device('cpu')
        else:
            dev = torch.device('cuda')

        return images.to(dev), targets.to(dev), lengths.to(dev)


##Model

In [109]:
import torch
from torch import nn
from torch.nn import functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += self.downsample(identity)
        out = self.relu(out)
        return out

class CNN(nn.Module):
    def __init__(self, input_channels, hidden_size):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.resblocks = nn.Sequential(
            ResidualBlock(64, 64),
            ResidualBlock(64, 128, stride=2),
            ResidualBlock(128, 256, stride=2)
        )
        self.conv2 = nn.Conv2d(256, hidden_size, kernel_size=1, stride=1, padding=0, bias=False)
        self.pool = nn.AdaptiveAvgPool2d((1, None))

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.resblocks(x)
        x = self.conv2(x)
        x = self.pool(x)
        return x

class BiLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(BiLSTM, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, x):
        out, _ = self.lstm(x)
        T, b, h = out.size() # T - time_steps
        # print(out.size())
        out = self.fc(out[:, -1, :])
        return out

class CRNN(nn.Module):
    def __init__(self, input_channels, hidden_size, num_layers, num_classes):
        super(CRNN, self).__init__()
        self.cnn = CNN(input_channels, hidden_size)
        self.rnn = BiLSTM(hidden_size, hidden_size, num_layers, num_classes)

    def forward(self, x):
        x = self.cnn(x)
        # print("shape",x.shape)
        x = x.squeeze(2)
        # print("after squeeze", x.shape )
        x = x.permute(2, 0, 1)
        x = self.rnn(x)
        print(x.shape)
        return x

# Transforms

In [110]:
# import torch
# from torch import nn, optim
# import numpy as np
# import matplotlib.pyplot as plt
# import os
# import pickle
# import Levenshtein as leven
# from skimage.color import rgb2gray
# from skimage.transform import rotate
# import matplotlib.pyplot as plt



In [111]:
# transforms
# Transform and Data Augmentation
from skimage import transform, color, filters
import cv2 as cv
import numpy as np
import torch
from torchvision.transforms import Normalize
import torchvision.transforms.functional as F

class PadResize(object):
    def __init__(self, output_size):
        assert isinstance(output_size, (tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        self.w_f, self.h_f = self.output_size
        ratio_final = self.w_f / self.h_f
        self.w, self.h = image.size
        self.ratio_current = self.w / self.h

        # check if the original and final aspect ratios are the same within a margin
        if round(self.ratio_current, 2) != round(ratio_final, 2):
            # padding to preserve aspect ratio
            hp = int(self.w/ratio_final - self.h)
            wp = int(ratio_final * self.h - self.w)

            if hp > 0 and wp < 0:
                hp = hp // 2
                image = F.pad(image, (0, hp, 0, hp), 0, "constant")
                image = F.resize(image, [self.h_f, self.w_f])
            elif wp > 0 and hp < 0:
                wp = wp // 2
                image = F.pad(image, (wp, 0, wp, 0), 0, "constant")
                image = F.resize(image, [self.h_f, self.w_f])
        else:
            image = F.resize(image,[self.h_f, self.w_f])

        return {'image': image, 'label': label}

class Deskew(object):
    """Deskew handwriting samples"""

    def __call__(self, sample):
        image, label = sample['image'], sample['label']

        try:
            threshold = filters.threshold_otsu(image)
        except ValueError:
            return {'image':image, 'label':label}

        binary = image.copy() < threshold

        # array of alpha values
        alphas = np.arange(-1, 1.1, 0.25)
        alpha_res = np.array([])
        alpha_params = []

        for a in alphas:
            alpha_sum = 0
            shift_x = np.max([-a*binary.shape[0], 0])
            M = np.array([[1, a, shift_x],
                          [0,1,0]], dtype=np.float64)
            img_size = (np.int(binary.shape[1] + np.ceil(np.abs(a*binary.shape[0]))), binary.shape[0])
            alpha_params.append((M, img_size))


            img_shear = cv.warpAffine(src=binary.astype(np.uint8),
                                      M=M, dsize=img_size,
                                      flags=cv.INTER_NEAREST)

            for i in range(0, img_shear.shape[1]):
                if not np.any(img_shear[:, i]):
                    continue

                h_alpha = np.sum(img_shear[:, i])
                fgr_pos = np.where(img_shear[:, i] == 1)
                delta_y_alpha = fgr_pos[0][-1] - fgr_pos[0][0] + 1

                if h_alpha == delta_y_alpha:
                    alpha_sum += h_alpha**2

            alpha_res = np.append(alpha_res, alpha_sum)

        best_M, best_size = alpha_params[alpha_res.argmax()]
        deskewed_img = cv.warpAffine(src=image, M=best_M, dsize=best_size,
                                      flags=cv.INTER_LINEAR,
                                      borderMode=cv.BORDER_CONSTANT,
                                      borderValue=255)

        return {'image':deskewed_img, 'label':label}

class toRGB(object):
    """Convert the ndarrys to RGB tensors.
       Required if using ImageNet pretrained Resnet."""

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        image = color.gray2rgb(image)

        return {'image': image, 'label': label}

class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __init__(self, rgb=True):
        assert isinstance(rgb, bool)
        self.rgb = rgb

    def __call__(self, sample):
        image, label = sample['image'], sample['label']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        image = F.to_tensor(image)
        return {'image': image, 'label': label}



class Normalize_Cust(object):
    """Normalise by channel mean and std"""

    def __init__(self, mean, std):
        self.mean = torch.tensor(mean, dtype=torch.float)
        self.std = torch.tensor(std, dtype=torch.float)
        self.norm = Normalize(mean, std)

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        return {'image': self.norm(image), 'label': label}

## Train

In [112]:
import torch
from torch import nn, optim
import numpy as np
import matplotlib.pyplot as plt
import os
import pickle
from skimage.color import rgb2gray
from skimage.transform import rotate
import matplotlib.pyplot as plt
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import Compose
from torchvision import transforms

# from utils import *
# from dataset import SharadaDataset
# from dataloader import SharadaDataLoader
# from transforms import PadResize, Deskew, toRGB, ToTensor, Normalize_Cust

os.makedirs("chk_pts/", exist_ok=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)



dataset = SharadaDataset(txt_dir='/content/extracted_dir/',
                        img_dir='/content/extracted_dir/',
                        transform=Compose([
                            # Deslant(),
                            PadResize(output_size=(64,200)),
                            ToTensor(), # converted to Tensor
                            Normalize_Cust(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
                        ]))

# dataset = DevDataset(images,
#                         labels,
#                         transform=Compose([
#                             # Deslant(),
#                             PadResize(output_size=(64,200)),
#                             ToTensor(), # converted to Tensor
#                             Normalize_Cust(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
#                         ]))

dl = SharadaDataLoader(dataset,
                       batch_size=(120,240),
                       validation_split=0.2,
                       shuffle=True,
                       seed=3407,
                       device=str(device))

crnn_model = CRNN(input_channels=3, hidden_size=512, num_layers=2, num_classes=len(dataset.char_dict) + 1).to(device)
optimizer = Adam(crnn_model.parameters(), lr=0.001)

ctc_loss = nn.CTCLoss(blank=0, reduction='mean')

cuda


In [113]:
print(crnn_model)

CRNN(
  (cnn): CNN(
    (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (resblocks): Sequential(
      (0): ResidualBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (downsample): Sequential()
      )
      (1): ResidualBlock(
        (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
  

In [114]:
from torchinfo import summary
summary(crnn_model, (100, 3, 64, 64))

torch.Size([16, 113])


Layer (type:depth-idx)                        Output Shape              Param #
CRNN                                          [16, 113]                 --
├─CNN: 1-1                                    [100, 512, 1, 16]         --
│    └─Conv2d: 2-1                            [100, 64, 64, 64]         1,728
│    └─BatchNorm2d: 2-2                       [100, 64, 64, 64]         128
│    └─ReLU: 2-3                              [100, 64, 64, 64]         --
│    └─Sequential: 2-4                        [100, 256, 16, 16]        --
│    │    └─ResidualBlock: 3-1                [100, 64, 64, 64]         73,984
│    │    └─ResidualBlock: 3-2                [100, 128, 32, 32]        230,144
│    │    └─ResidualBlock: 3-3                [100, 256, 16, 16]        919,040
│    └─Conv2d: 2-5                            [100, 512, 16, 16]        131,072
│    └─AdaptiveAvgPool2d: 2-6                 [100, 512, 1, 16]         --
├─BiLSTM: 1-2                                 [16, 113]                 

In [98]:
train_loader, val_loader = dl()

In [102]:
writer = SummaryWriter()
num_epochs = 10

# Training loop
for epoch in range(num_epochs):
    crnn_model.train()  # Set the model to training mode
    total_loss = 0.0

    # Iterate over the training dataset
    for images, targets, lengths in train_loader:  # Assuming dl() returns train_loader
        # print("Here:",images)
        images = images.to(device)
        targets = targets.to(device)

        # Forward pass
        logits = crnn_model(images)
        print(logits.shape)

        # Calculate the CTC loss
        loss = ctc_loss(logits, targets, lengths, lengths)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Calculate average training loss for the epoch
    avg_loss = total_loss / len(train_loader)

    # Log the training loss to Tensorboard
    writer.add_scalar('Loss/Train', avg_loss, epoch)

    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}')

    # Validation
    if (epoch + 1) % 1 == 0:  # You can adjust the frequency of validation
        crnn_model.eval()  # Set the model to evaluation mode
        val_loss = 0.0

        # Iterate over the validation dataset
        with torch.no_grad():
            for val_images, val_targets, val_lengths in val_loader:  # Assuming dl() returns validation_loader
                val_images = val_images.to(device)
                val_targets = val_targets.to(device)

                # Forward pass
                val_logits = crnn_model(val_images)

                # Calculate the CTC loss
                val_loss += ctc_loss(val_logits, val_targets, val_lengths, val_lengths).item()

                _, predicted_labels = torch.max(val_logits, 2)
                predicted_labels = ["".join([dataset.char_list[c] for c in row if c != 0]) for row in predicted_labels.cpu().numpy()]

                for pred, target in zip(predicted_labels, val_targets.cpu().numpy()):
                    distance = levenshtein_distance(pred, "".join([dataset.char_list[c] for c in target if c != 0]))

                    writer.add_scalar('LevenshteinDistance/Validation', distance, epoch)

        avg_val_loss = val_loss / len(val_loader)
        writer.add_scalar('Loss/Validation', avg_val_loss, epoch)
        crnn_model.train()

        print(f'Validation Loss: {avg_val_loss:.4f}')

# Save the trained model
torch.save(crnn_model.state_dict(), 'chk_pts/crnn_model.pth')

# Close Tensorboard writer
writer.close()

shape torch.Size([120, 512, 1, 16])
after squeeze torch.Size([120, 512, 16])
torch.Size([16, 120, 1024])
torch.Size([16, 113])


RuntimeError: input_lengths must be of size batch_size

## Other Model Sample

In [None]:
import torch
import torch.nn as nn

class IEncoder(nn.Module):
    def __init__(self):
        super(IEncoder, self).__init__()

        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU()
        )

    def forward(self, x):
        return self.encoder(x)

class TEncoder(nn.Module):
    def __init__(self, input_size, latent_size):
        super(TEncoder, self).__init__()

        self.encoder = nn.LSTM(input_size, latent_size, batch_first=True)

    def forward(self, x):
        _, (hidden, _) = self.encoder(x)
        return hidden.squeeze(0)

class LatentProjectionSpace(nn.Module):
    def __init__(self, latent_size):
        super(LatentProjectionSpace, self).__init__()

        self.projection = nn.Linear(64 + latent_size, latent_size)

    def forward(self, img_emb, txt_emb):
        combined_latent = torch.cat((img_emb.view(img_emb.size(0), -1), txt_emb), dim=1)
        return self.projection(combined_latent)

class IDecoder(nn.Module):
    def __init__(self):
        super(IDecoder, self).__init__()

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 3, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.decoder(x)

class TDecoder(nn.Module):
    def __init__(self, latent_size, output_size):
        super(TDecoder, self).__init__()

        self.decoder = nn.LSTM(latent_size, output_size, batch_first=True)

    def forward(self, x):
        output, _ = self.decoder(x.unsqueeze(1))
        return output.squeeze(1)

class ImgTextAutoencoder(nn.Module):
    def __init__(self, img_input_size, txt_input_size, latent_size):
        super(ImgTextAutoencoder, self).__init__()

        self.image_encoder = IEncoder()
        self.text_encoder = TEncoder(txt_input_size, latent_size)
        self.latent_projection = LatentProjectionSpace(latent_size)
        self.image_decoder = IDecoder()
        self.text_decoder = TDecoder(latent_size, txt_input_size)

    def forward(self, img_input, txt_input):
        img_emb = self.image_encoder(img_input)
        txt_emb = self.text_encoder(txt_input)
        latent = self.latent_projection(img_emb, txt_emb)
        rec_img = self.image_decoder(img_emb)
        rec_txt = self.text_decoder(latent)
        return rec_img, rec_txt

img_input_size = (3, 64, 64)
txt_input_size = 300
latent_size = 100

img_text_autoencoder = ImgTextAutoencoder(img_input_size, txt_input_size, latent_size)

print(img_text_autoencoder)
