In [None]:
import csv
import os
import sys

import numpy as np
import matplotlib.pyplot as plt
import PIL
from PIL import Image
from skimage import io, transform

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader

import torchvision.models as models
import torchvision.transforms as transforms

In [None]:
class MovieGenreClassifier(nn.Module):
    """Small fully-connected classifier head.

    The stack is ``Linear(nclass, 64) -> ReLU -> Linear(64, nlabel)``. No
    activation is applied after the last layer, so ``forward`` returns the
    raw outputs of the final linear layer.

    Args:
        nclass: Size of the last dimension of the input features.
        nlabel: Number of output values produced per sample.
    """

    def __init__(self, nclass, nlabel):
        super(MovieGenreClassifier, self).__init__()
        hidden_width = 64
        # Layers are created in forward order so parameter initialisation
        # (and the ``main.0`` / ``main.2`` state-dict keys) stay the same.
        stages = [
            nn.Linear(nclass, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, nlabel),
        ]
        self.main = nn.Sequential(*stages)

    def forward(self, input):
        """Feed ``input`` through the layer stack and return the result."""
        scores = self.main(input)
        return scores

def extractObjectFeature(image, networks):
    """Placeholder for object-feature extraction.

    Not implemented: both arguments are ignored and ``None`` is returned.
    """
    return None
    
def extractLabel(image, id2genre, genresTable, nlabel=23):
    """Build a multi-hot genre vector for one image.

    Args:
        image: Key into ``id2genre`` identifying the image.
        id2genre: Mapping from image key to a ``|``-separated genre string.
        genresTable: Mapping from genre name to its column index in the
            returned vector. Every index must be smaller than ``nlabel``.
        nlabel: Width of the returned vector. Defaults to 23, the value that
            used to be hard-coded, so existing callers are unaffected.

    Returns:
        A ``torch`` tensor of shape ``(1, nlabel)`` holding 1 in the column
        of every genre of ``image`` that is present in ``genresTable`` and 0
        everywhere else. Genres missing from ``genresTable`` are ignored.

    Raises:
        KeyError: If ``image`` is not a key of ``id2genre``.
    """
    labelVec = torch.zeros(1, nlabel)
    for genre in id2genre[image].split('|'):
        if genre in genresTable:
            labelVec[0, genresTable[genre]] = 1

    return labelVec

In [None]:
# ``transforms.Scale`` is the deprecated name of ``transforms.Resize`` and is
# absent from newer torchvision releases. Prefer ``Resize`` and only fall back
# to ``Scale`` on installations that predate it.
_resize = getattr(transforms, "Resize", None) or transforms.Scale

# NOTE(review): ``Resize`` documents a tuple size as (height, width) -- confirm
# that (182, 268) is the intended orientation for the posters.
loader = transforms.Compose([
    _resize((182, 268)),
    transforms.ToTensor()])  # transform it into a torch tensor

class MyDataset(Dataset):
    """Image dataset that yields ``(feature, label)`` pairs.

    Every entry of ``root`` is decoded at construction time and kept in
    memory together with its multi-hot genre label. Feature extraction
    happens lazily in ``__getitem__``.

    Args:
        root: Directory containing the images. Each entry name is used as the
            lookup key for its genres, so files are expected to be named
            ``<id>.jpg`` where ``<id>`` is column 0 of the CSV.
        csvfile: Path of the metadata CSV. Column 0 is read as the id and
            column 4 as a ``|``-separated genre string.
        networks: Opaque object forwarded to ``extractStyleFeature``.
        transform: Optional callable applied to the PIL image before the
            module-level ``loader`` pipeline.
        min_genre_count: Genres that occur fewer times than this in the CSV
            are left out of the label table. Defaults to 100, the previously
            hard-coded threshold. NOTE(review): ``extractLabel`` builds
            fixed-width vectors by default, so lowering the threshold may
            produce more genres than it can index -- verify before changing.

    Attributes set by the constructor:
        root, transform, networks: The constructor arguments.
        csvfile: The CSV *path* (previously an exhausted, never-closed file
            handle was stored here).
        id2genre: ``{"<id>.jpg": "Genre|Genre|..."}`` built from the CSV.
        genresTable: ``{genre: column index}`` for the retained genres.
        dataset: List of decoded images, in ``os.listdir(root)`` order.
        labels: List of label tensors aligned with ``dataset``.
    """

    def __init__(self, root, csvfile, networks, transform=None,
                 min_genre_count=100):
        self.root = root
        self.transform = transform
        self.csvfile = csvfile
        self.networks = networks

        self.id2genre, self.genresTable = self._read_genre_tables(
            csvfile, min_genre_count)

        # tqdm only draws a progress bar; fall back to plain iteration when
        # it is not installed so that loading still works.
        try:
            from tqdm import tqdm as progress
        except ImportError:
            def progress(iterable):
                return iterable

        self.dataset = []
        self.labels = []
        for img in progress(os.listdir(self.root)):
            self.dataset.append(io.imread(os.path.join(self.root, img)))
            self.labels.append(
                extractLabel(img, self.id2genre, self.genresTable))

    @staticmethod
    def _read_genre_tables(csvfile, min_genre_count=100):
        """Parse the CSV once and return ``(id2genre, genresTable)``."""
        if sys.version_info[0] >= 3:
            # csv.reader needs text mode with newline='' on Python 3.
            # latin-1 maps every byte to a character, so decoding cannot fail
            # and the ASCII id / genre columns are read unchanged.
            handle = open(csvfile, 'r', newline='', encoding='latin-1')
        else:
            handle = open(csvfile, 'rb')

        id2genre = {}
        occurrences = {}
        try:
            for row in csv.reader(handle):
                if len(row) < 5:
                    # Blank or truncated line: there is no genre column.
                    continue
                if row[0] != "":
                    id2genre[row[0] + ".jpg"] = row[4]
                for name in row[4].split('|'):
                    if name != '':
                        occurrences[name] = occurrences.get(name, 0) + 1
        finally:
            handle.close()

        # Indices follow the dict's iteration order over the retained genres,
        # exactly as before; rare genres are simply skipped.
        genresTable = {}
        for name in occurrences:
            if occurrences[name] >= min_genre_count:
                genresTable[name] = len(genresTable)

        return id2genre, genresTable

    def __len__(self):
        """Number of images loaded at construction time."""
        # Use the in-memory list rather than re-listing ``root``: the
        # directory may change after construction, the loaded data cannot.
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(feature, label)`` for the ``idx``-th loaded image."""
        image = Image.fromarray(self.dataset[idx])

        if self.transform is not None:
            image = self.transform(image)

        image = Variable(loader(image))

        # Move to the GPU when one exists instead of failing on CPU-only hosts.
        if torch.cuda.is_available():
            image = image.cuda()
        # fake batch dimension required to fit network's input dimensions
        image = image.unsqueeze(0)

        feature = extractStyleFeature(image, self.networks)

        return feature.data, self.labels[idx]
                               
# Both splits read the same metadata CSV; only the image directory differs.
# NOTE(review): `networks` must already be defined by an earlier cell, and the
# paths below are machine-specific absolute paths.
genre_csv = '/home/ubuntu/notebooks/Movie-Genre-Classification-from-Movie-Poster/Dataset/NewMovieGenre.csv'

trainset = MyDataset(
    root='/home/ubuntu/notebooks/dataset/train',
    csvfile=genre_csv,
    networks=networks,
)
valset = MyDataset(
    root='/home/ubuntu/notebooks/dataset/validation/',
    csvfile=genre_csv,
    networks=networks,
)