**This notebook is dedicated to build a feedforward neural net to classify the actors.**

The dataset fed to the model is the set of the resized equalized grayscale images.

# Setup the dataset

In [1]:
# !pip install pytorch_lightning
import pytorch_lightning as pl
import torch
from torch import nn
from torch.utils.data import DataLoader,random_split
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning import Trainer

import pandas as pd
import numpy as np
import cv2 as cv

In [None]:
# Step 1:  Downloading the dataset.

! wget https://datahack-prod.s3.amazonaws.com/test_zip/test_Bh8pGW3.zip
! unzip test_Bh8pGW3.zip
! wget https://datahack-prod.s3.amazonaws.com/train_zip/train_DETg9GD.zip
! unzip train_DETg9GD.zip

In [3]:
# Step 2: Define a pl DataModule

class ImagesLoader(pl.LightningDataModule):

    def __init__(self,classes:dict,batch_size=32):
        super(ImagesLoader,self).__init__()
        self.batch_size=batch_size
        self.classes=classes

    def prepare_data(self,num_imgs=1e10,img_size=(100,100)):
        # Training data
        train_csv=pd.read_csv("train.csv")
        self.train_set=[]
        # Only get num_imgs image due to transformation complications
        for i in range(min(num_imgs,train_csv.shape[0])):
            img=cv.imread(f"Train/{train_csv.loc[i,'ID']}",0)
            img=cv.equalizeHist(img)
            img=cv.resize(img,img_size)
            # img=cv.Canny(img,100,200)
            self.train_set.append((torch.tensor(img).float()/256.,train_csv.loc[i,'Class']))

    def setup(self):
        # Which classes to keep in the train & val datasets
        data=[]
        keys=list(self.classes.keys())
        for ex in self.train_set:
            if ex[1] in keys:
                data.append((ex[0],self.classes[ex[1]]))
        # Split the data 80% and 20%. The latter for validation
        thres=8*len(data)//10
        self.train,self.val=random_split(data,[thres,len(data)-thres])

    def train_dataloader(self):
        return DataLoader(self.train,self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.val,self.batch_size)

# Binary Classification

1- MLP model to discriminate between young and old actors.

2-  MLP model to discriminate between middle-aged and old actors.

3- Combine the models in a single one.

In [61]:
# Step 1: Train a 1e4, 1e3, 500, 1 with sigmoid output and softmargin loss

class MLP(LightningModule):

    def __init__(self,in_dim):
        super().__init__()

        self.model=nn.Sequential(
                            nn.Flatten(),

                            nn.Linear(in_dim,1000),
                            nn.PReLU(1000),

                            nn.Linear(1000,500),
                            nn.PReLU(500),
                            
                            nn.Linear(500,500),
                            nn.Softmax(),
                            
                            nn.Linear(500,100),
                            nn.PReLU(100),
                            
                            nn.Linear(100,200),
                            nn.Tanh(),
                            
                            nn.Linear(200,50),
                            nn.PReLU(50),                                                                  
                            
                            nn.Linear(50,1),
                            nn.PReLU()
                                 )
        self.cost=nn.SoftMarginLoss()

    def forward(self,x):
        return self.model(x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(),lr=1e-9,weight_decay=1)

    def training_step(self,batch,batch_idx):
        x,y=batch
        y_hat=self(x).flatten()
        return self.cost(y_hat,y.float())

    def validation_step(self,batch,batch_idx):
        x,y=batch
        y_hat=self(x).flatten()
        return self.cost(y_hat,y.float())

    def validation_epoch_end(self,outs):
        print(f'Validation for this epoch: \t\t{max(outs)}')

In [43]:
data=ImagesLoader({'YOUNG':-1,'OLD':1})
data.prepare_data(img_size=(100,100))
data.setup()

In [None]:
# ML model
model=MLP(10000)
trainer=Trainer(gpus=-1,progress_bar_refresh_rate=20)
trainer.fit(model,data)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type           | Params
-----------------------------------------
0 | model | Sequential     | 10.8 M
1 | cost  | SoftMarginLoss | 0     
-----------------------------------------
10.8 M    Trainable params
0         Non-trainable params
10.8 M    Total params


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

Validation for this epoch: 		0.7073864340782166


  input = module(input)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.717204213142395


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.7172040939331055


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.7172039747238159


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.7172038555145264


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.7172037959098816


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.717203676700592


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.7172035574913025


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.7172034978866577


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.7172033786773682


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.7172031998634338


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.7172031402587891


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.7172030210494995


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.71720290184021


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.7172027826309204


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.7172027230262756


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.7172025442123413


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.7172024250030518


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.717202365398407


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.7172022461891174


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.7172021269798279


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.7172020077705383


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.7172019481658936


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.717201828956604


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.7172017097473145


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.7172015905380249


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.7172014713287354


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.7172013521194458


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.7172012329101562


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.7172011733055115


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.7172010540962219


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.7172008752822876


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.7172008752822876


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.717200756072998


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.7172006368637085


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.717200517654419


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.7172003984451294


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.7172002792358398


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.7172001600265503


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.7172000408172607


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.7171999216079712


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.7171998023986816


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.7171997427940369


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation for this epoch: 		0.7171996235847473


**To be continued**