In [187]:
import pandas as pd, torch
from PIL import Image
from sklearn.model_selection import train_test_split

In [30]:
df = pd.read_csv('data/movie_db.csv').dropna()

In [189]:
# Fixed seed so the 70/30 split is the same on every Restart & Run All.
train, test = train_test_split(df, test_size=0.3, random_state=42)

## Process Genre Tags

In [32]:
genre = df['genre'].tolist()

def get_genres(genre):
    """Build index <-> genre lookup tables from comma-separated genre strings.

    Args:
        genre: iterable of strings such as ``'Comedy, Drama'``. Each string is
            split on ``','`` and every tag is stripped of surrounding
            whitespace.

    Returns:
        ``(idx2genre, genre2idx)``: a dict mapping index -> tag and a dict
        mapping tag -> index. Indices start at 0 and follow the order in
        which tags are first seen. Both dicts are empty for empty input.
    """
    # A dict preserves insertion order and gives O(1) membership checks,
    # replacing the linear scan over a list.
    genre2idx = {}

    for tags in genre:
        for tag in tags.split(','):
            tag = tag.strip()
            if tag not in genre2idx:
                genre2idx[tag] = len(genre2idx)

    # Built once, after the scan: the reverse table used to be rebuilt for
    # every row and was left undefined when `genre` was empty.
    idx2genre = {idx: tag for tag, idx in genre2idx.items()}

    return idx2genre, genre2idx

In [33]:
idx2genre, genre2idx = get_genres(genre)

In [34]:
def count_genre(genre, genre2idx):
    """Count how many times each known genre tag occurs.

    Args:
        genre: iterable of comma-separated genre strings.
        genre2idx: dict whose keys are the known genre tags; every key gets
            an entry in the result, starting at 0.

    Returns:
        Dict mapping tag -> number of occurrences across all strings.

    Raises:
        KeyError: if a string contains a tag that is not a key of
            ``genre2idx``.
    """
    tally = dict.fromkeys(genre2idx, 0)

    for entry in genre:
        for tag in entry.split(','):
            tally[tag.strip()] += 1

    return tally

In [35]:
genre_counts = count_genre(genre, genre2idx)

In [114]:
def encode_genre(genre, genre2idx):
    """Turn one comma-separated genre string into a tensor of genre indices.

    Args:
        genre: a single string such as ``'Comedy, Drama'``.
        genre2idx: dict mapping each (stripped) tag to its integer index.

    Returns:
        ``torch.LongTensor`` holding one index per tag, in string order.

    Raises:
        KeyError: if a tag is missing from ``genre2idx``.
    """
    indices = []
    for tag in genre.split(','):
        indices.append(genre2idx[tag.strip()])

    return torch.LongTensor(indices)

In [39]:
# `encode_genre` encodes a single film's genre string, so apply it per film;
# the result is one LongTensor per row of `genre`.
encoded_genres = [encode_genre(g, genre2idx) for g in genre]

## Process Plot and Build Vocab

In [40]:
import re
from tqdm.notebook import tqdm

In [41]:
plots = df['plot']

In [42]:
def reg_remove(plot):
    """Drop every character that is not a word character, a space or a hyphen.

    Args:
        plot: the text to clean.

    Returns:
        ``plot`` with all other characters (punctuation, newlines, ...) removed.
    """
    return re.sub(r'[^\w -]', '', plot)


In [43]:
def build_vocab(plots):
    """Tokenise every plot and build the word <-> index vocabulary.

    Each plot is lower-cased, cleaned with ``reg_remove``, split on
    whitespace and wrapped in ``'<start>'`` / ``'<end>'`` markers.

    Args:
        plots: iterable of plot strings.

    Returns:
        ``(vocab, idx2wrd, processed_plots)``: ``vocab`` maps token -> index,
        ``idx2wrd`` is the reverse mapping, and ``processed_plots`` is the
        list of token lists, one per plot.
    """
    vocab = {}
    processed_plots = []

    for plot in tqdm(plots):
        # split() with no argument collapses runs of spaces; split(' ')
        # produced empty-string tokens wherever removed punctuation left two
        # adjacent spaces, and those ended up in the vocabulary.
        tokens = ['<start>'] + reg_remove(plot.lower()).split() + ['<end>']

        for token in tokens:
            if token not in vocab:
                # Indices start at 1, which leaves 0 unused by real tokens
                # (pad_sequence pads with 0 by default).
                vocab[token] = len(vocab) + 1

        processed_plots.append(tokens)

    idx2wrd = {idx: wrd for wrd, idx in vocab.items()}

    return vocab, idx2wrd, processed_plots
    


In [44]:
wrd2idx, idx2wrd, processed_plots = build_vocab(plots)

HBox(children=(FloatProgress(value=0.0, max=6398.0), HTML(value='')))




In [46]:
from torch.nn.utils.rnn import pad_sequence

In [119]:
def encode_plot(plot, wrd2idx):
    """Map a sequence of tokens to their vocabulary indices.

    Args:
        plot: iterable of tokens. It is iterated element by element, so a raw
            string would be encoded character by character.
        wrd2idx: dict mapping token -> index.

    Returns:
        List of indices; tokens missing from ``wrd2idx`` are encoded as
        ``len(wrd2idx) + 1``.
    """
    unknown_idx = len(wrd2idx) + 1
    return [wrd2idx.get(token, unknown_idx) for token in plot]

In [48]:
def encode_plots(plots, wrd2idx):
    """Encode every tokenised plot and pad them into one batch tensor.

    Args:
        plots: iterable of token lists (e.g. ``processed_plots`` returned by
            ``build_vocab``).
        wrd2idx: dict mapping token -> index.

    Returns:
        The result of ``pad_sequence(..., batch_first=True)`` over one
        ``LongTensor`` per plot, i.e. one padded row per plot.
    """
    encoded = []

    for plot in tqdm(plots):
        # The helper is `encode_plot(plot, wrd2idx)`; the previous call used
        # an undefined name `encode` with the arguments swapped.
        encoded.append(torch.LongTensor(encode_plot(plot, wrd2idx)))

    return pad_sequence(encoded, batch_first=True)

In [49]:
encoded = encode_plots(processed_plots, wrd2idx)

HBox(children=(FloatProgress(value=0.0, max=6398.0), HTML(value='')))




## Process Images

In [140]:
import os
# sample execution (requires torchvision)
from PIL import Image
from torchvision import transforms
import torch

# Preprocessing applied to every poster, in order:
# Resize(512) -> CenterCrop(448) -> ToTensor -> per-channel Normalize.
# NOTE(review): the mean/std values look like the usual ImageNet channel
# statistics — confirm this matches the model the tensors are fed to.
transform = transforms.Compose([
    transforms.Resize(512),
    transforms.CenterCrop(448),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Inverse of the Normalize step above. Normalize computes (x - mean) / std, so
# using mean' = -mean/std and std' = 1/std yields x * std + mean.
invTrans = transforms.Compose([
                                transforms.Normalize(mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225], 
                                                     std=[1/0.229, 1/0.224, 1/0.225]),
                               ])

In [182]:
def process_image(filename, from_path):
    """Load an image file and run it through the module-level ``transform``.

    Args:
        filename: path of the image to open.
        from_path: when falsy, a de-normalised copy of the result is also
            written to ``data/processed_posters/<name>-processed.jpeg`` via
            ``output_image``; when truthy the tensor is only returned.

    Returns:
        The tensor produced by ``transform``.
    """
    # The context manager closes the file handle once the pixels are read.
    # convert('RGB') gives greyscale / RGBA / palette images the three
    # channels that the 3-value Normalize in `transform` is configured for.
    with Image.open(filename) as input_image:
        transformed = transform(input_image.convert('RGB'))

    if not from_path:
        # splitext drops the extension whatever its length (the old [:-5]
        # slice only worked for '.jpeg').
        stem = os.path.splitext(os.path.basename(filename))[0]
        output_image(transformed, 'data/processed_posters/{}-processed.jpeg'.format(stem))

    return transformed
    
    

In [151]:
def output_image(image, filename):
    """De-normalise an image tensor with ``invTrans`` and save it to disk.

    Args:
        image: tensor as produced by ``transform``.
        filename: destination path passed to ``PIL.Image.save``.
    """
    # Clamp to [0, 1]: float round-off in the inverse normalisation can leave
    # values slightly outside the range, which corrupts the 8-bit conversion.
    restored = invTrans(image).clamp(0, 1)

    # Go through `transforms`, the only torchvision name this notebook
    # imports; the bare `ToPILImage` was undefined.
    transforms.ToPILImage()(restored).save(filename)
    
    

In [145]:
processed = torch.stack(processed_images)

RuntimeError: stack expects a non-empty TensorList

In [146]:
# File name without directory or extension, e.g.
# 'data/posters/1-toy-story.jpeg' -> '1-toy-story'. Replaces the [13:-5] slice,
# which was tied to the 'data/posters/' prefix and the '.jpeg' suffix.
film_ids = [os.path.splitext(os.path.basename(p))[0] for p in df['poster_path'].tolist()]

In [147]:
from torch.utils.data import Dataset, DataLoader

In [170]:
dataset = film_dataset(film_ids, encoded, processed, encoded_genres)

In [210]:
dataset[10]

{'film_id': '11-the-american-president',
 'plot': tensor([  1, 475, 476,  24, 477,  22, 478,  71,  11, 479, 480,  34, 416,  71,
          22, 417, 418, 481,   2, 482,  14,   2, 383, 220,  44,  33, 483,   8,
           9, 484,  14,  22, 485,  11, 486, 487, 488,  31,  24, 365,  32, 489,
         490, 491, 492,   2, 493, 494, 495, 496,  97,  53, 497, 498, 499, 481,
         500, 501,  14, 473,  44,  91, 408, 502, 503,  27, 504, 505,   8, 407,
         115,   2, 506, 507, 508,   8,   9,   2, 509, 510,  97,  22, 511, 416,
          71, 512,  11, 513,  14, 492, 514,  22, 515,  97,  11, 516, 494, 517,
         518, 519, 520,   8, 521,  53, 522, 115,  22, 523, 524, 164,  32, 525,
          25,  21,  10,  22, 526, 486,  34, 527, 221, 528, 166, 529, 530,  90,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
   

## Decode 

In [213]:
def decode_plot(idx2wrd, plot):
    """Turn a sequence of token indices back into a plot string.

    Args:
        idx2wrd: dict mapping index -> token.
        plot: iterable of indices (list or 1-D tensor); 0 entries are
            treated as padding and skipped.

    Returns:
        The decoded tokens joined by single spaces, with the first and last
        non-padding token (the start/end markers) dropped. Indices missing
        from ``idx2wrd`` are rendered as ``'<unk>'``.
    """
    ids = [int(i) for i in plot]

    # .get(): encode_plot emits len(wrd2idx) + 1 for out-of-vocabulary
    # tokens, an index idx2wrd has no entry for, which raised KeyError here.
    words = [idx2wrd.get(i, '<unk>') for i in ids if i != 0]

    return ' '.join(words[1:-1])

In [218]:
def decode_genre(genre, idx2genre):
    """Map a sequence of genre indices back to their tag names.

    Args:
        genre: iterable of indices (list or 1-D tensor).
        idx2genre: dict mapping index -> tag.

    Returns:
        List of tags, in the same order as ``genre``.

    Raises:
        KeyError: if an index is missing from ``idx2genre``.
    """
    indices = list(map(int, genre))
    return [idx2genre[index] for index in indices]

In [219]:
decode_genre(dataset[1]['genre'], idx2genre)

['Adventure', 'Comedy', 'Family', 'Fantasy']

In [251]:
def view_image(image):
    """Convert a normalised image tensor back into a viewable PIL image.

    Args:
        image: tensor normalised with mean ``[0.485, 0.456, 0.406]`` and std
            ``[0.229, 0.224, 0.225]``.

    Returns:
        The de-normalised image as returned by ``transforms.ToPILImage``.
    """
    # Normalize computes (x - mean) / std, so these parameters invert the
    # normalisation above: the result is x * std + mean.
    inverse_normalize = transforms.Normalize(
        mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
        std=[1/0.229, 1/0.224, 1/0.225],
    )

    # Clamp away float round-off outside [0, 1] before the 8-bit conversion.
    restored = inverse_normalize(image).clamp(0, 1)

    # Go through `transforms`, the only torchvision name this notebook
    # imports; the bare `ToPILImage` was undefined.
    return transforms.ToPILImage()(restored)


In [254]:
image = view_image(dataset[23]['poster'])

In [157]:

# Smoke test: run the three encoders on the first row only.
# Results go into row_* names so the module-level `genre` list is not overwritten.
for index, row in df.iterrows():
    print(row)
    row_genre = encode_genre(row['genre'], genre2idx)
    # encode_plot iterates over tokens, so tokenise the raw text first
    # (lower-case, clean, add the start/end markers) instead of passing the string.
    row_tokens = ['<start>'] + reg_remove(row['plot'].lower()).split() + ['<end>']
    row_plot = encode_plot(row_tokens, wrd2idx)
    print(row['poster_path'])
    # process_image requires `from_path`; True returns the tensor without writing a file.
    row_poster = process_image(row['poster_path'], True)
    break

id                                                             1
title                                                  Toy Story
genre              Animation, Adventure, Comedy, Family, Fantasy
imdb_link                                              tt0114709
plot           A little boy named Andy loves to be in his roo...
poster_path                        data/posters/1-toy-story.jpeg
Name: 0, dtype: object
data/posters/1-toy-story.jpeg


In [180]:
class film_dataset(Dataset):
    """Dataset that encodes genre, plot and poster for each row of a DataFrame.

    NOTE(review): this notebook defines ``film_dataset`` twice with different
    constructors; whichever cell ran last is the one in effect.

    Args:
        df: DataFrame with ``id``, ``genre``, ``plot`` and ``poster_path``
            columns.
        wrd2idx: dict mapping token -> index, used to encode plots.
        genre2idx: dict mapping genre tag -> index, used to encode genres.
    """

    def __init__(self, df, wrd2idx, genre2idx):

        self.film_id = []
        self.genre = []
        self.plot = []
        self.poster = []
        # Rows that could not be encoded (unreadable poster, unknown genre).
        self.failed = []

        self.wrd2idx = wrd2idx
        self.genre2idx = genre2idx

        for _, row in tqdm(df.iterrows(), total=len(df)):
            # Encode everything first and append afterwards, so the four
            # lists stay aligned when a row fails part-way through.
            try:
                genre = encode_genre(row['genre'], genre2idx)
                plot = encode_plot(self._tokenize(row['plot']), wrd2idx)
                poster = process_image(row['poster_path'], True)
            except (OSError, KeyError):
                self.failed.append(row)
                continue

            self.film_id.append(row['id'])
            self.genre.append(genre)
            self.plot.append(plot)
            self.poster.append(poster)

    @staticmethod
    def _tokenize(plot):
        """Lower-case, clean and split a raw plot, adding start/end markers."""
        # encode_plot iterates over tokens; handing it the raw string would
        # encode the plot character by character.
        return ['<start>'] + reg_remove(plot.lower()).split() + ['<end>']

    def __getitem__(self, idx):
        """Return the encoded fields of the film at position ``idx`` as a dict."""
        return {
            'film_id' : self.film_id[idx],
            'plot'    : self.plot[idx],
            'poster'  : self.poster[idx],
            'genre'   : self.genre[idx]
        }

    def __len__(self):
        """Number of successfully encoded films."""
        return len(self.film_id)
    

In [184]:
dataset = film_dataset(df[:50], wrd2idx, genre2idx)

HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))




In [171]:
class film_dataset(Dataset):
    """Dataset over already-encoded, index-aligned film fields.

    NOTE(review): this notebook defines ``film_dataset`` twice with different
    constructors; whichever cell ran last is the one in effect.

    Args:
        film_id: indexable collection of film identifiers.
        plot: indexable collection of encoded plots.
        poster: indexable collection of processed posters.
        genre: indexable collection of encoded genres.
    """

    def __init__(self, film_id, plot, poster, genre):

        self.film_id = film_id
        self.plot = plot
        self.poster = poster
        self.genre = genre

    def __getitem__(self, idx):
        """Return the four fields of the film at position ``idx`` as a dict."""
        return {
            'film_id' : self.film_id[idx],
            'plot'    : self.plot[idx],
            'poster'  : self.poster[idx],
            'genre'   : self.genre[idx]
        }

    def __len__(self):
        """Number of films, taken from ``film_id``."""
        # Was `len(film_id)` with the receiver spelled `sef`: `film_id` is not
        # a name in scope there, so len(dataset) raised NameError.
        return len(self.film_id)

In [186]:
df

Unnamed: 0,id,title,genre,imdb_link,plot,poster_path
0,1,Toy Story,"Animation, Adventure, Comedy, Family, Fantasy",tt0114709,A little boy named Andy loves to be in his roo...,data/posters/1-toy-story.jpeg
1,2,Jumanji,"Adventure, Comedy, Family, Fantasy",tt0113497,After being trapped in a jungle board game for...,data/posters/2-jumanji.jpeg
2,3,Grumpier Old Men,"Comedy, Romance",tt0113228,Things don't seem to change much in Wabasha Co...,data/posters/3-grumpier-old-men.jpeg
3,4,Waiting to Exhale,"Comedy, Drama, Romance",tt0114885,This story based on the best selling novel by ...,data/posters/4-waiting-to-exhale.jpeg
4,5,Father of the Bride Part II,"Comedy, Family, Romance",tt0113041,"In this sequel to ""Father of the Bride"", Georg...",data/posters/5-father-of-the-bride-part-ii.jpeg
...,...,...,...,...,...,...
6403,50842,The Boss of It All,Comedy,tt0469754,The owner of an IT firm wants to sell up. The ...,data/posters/50842-the-boss-of-it-all.jpeg
6404,50851,Cocaine Cowboys,"Documentary, Crime, History",tt0380268,"In the 1980s, ruthless Colombian cocaine baron...",data/posters/50851-cocaine-cowboys.jpeg
6405,50872,Ratatouille,"Animation, Adventure, Comedy, Family, Fantasy",tt0382932,A rat named Remy dreams of becoming a great Fr...,data/posters/50872-ratatouille.jpeg
6406,50912,"Paris, je t'aime","Comedy, Drama, Romance",tt0401711,"Paris, je t'aime is about the plurality of cin...",data/posters/50912-paris-je-taime.jpeg
