# IMPORTACIONES Y DEFINICIÓN DE ALGUNAS VARIABLES NECESARIAS

In [1]:
import pytorch_lightning as pl
import pandas as pd
import cv2
import os 
from torch import nn
from torch.utils.data import Dataset ,DataLoader
import numpy as np
import torch
from sklearn.model_selection import train_test_split 
import torchvision
import torchmetrics as metrics
import matplotlib.image as mpimg
from PIL import Image
import glob
import matplotlib.pyplot as plt

# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# Locations of the labelled training images and of the unlabelled test images.
TRAIN_PATH = "../input/iais22-birds/birds/birds"
TEST_PATH = "../input/iais22-birds/submission_test/submission_test"

# Side length (in pixels) of the square network input, and number of output classes.
IMG_SIZE = 64
CLASSES = 400


# DEFINICIÓN DE LA CNN

Procedemos a implementar la red neuronal convolucional descrita en nuestro artículo.

*NOTA: Tras la defensa, se realizaron cambios en las capas `nn.Linear`, añadiendo una más y ajustando mejor los valores de sus entradas y salidas.*

In [2]:
class BirdsModel(pl.LightningModule):
    """Small CNN classifier for ``IMG_SIZE`` x ``IMG_SIZE`` RGB images.

    Architecture: Conv2d(3 -> 128, kernel 5, stride 4) -> ReLU -> BatchNorm2d
    -> MaxPool2d(4) -> Flatten -> four Linear layers, the last one producing
    ``CLASSES`` raw scores (logits). ``forward`` applies no softmax because
    the cross-entropy loss used in ``loss_fn`` takes unnormalised logits.
    """

    def __init__(self):
        super().__init__()
        # For a 64x64 input: the conv yields 15x15 maps, the pool 3x3,
        # so Flatten produces 128 * 3 * 3 = 1152 features for fc1.
        self.cnv = nn.Conv2d(3, 128, 5, 4)
        self.rel = nn.ReLU()
        self.bn = nn.BatchNorm2d(128)
        self.mxpool = nn.MaxPool2d(4)
        self.flat = nn.Flatten()
        self.fc1 = nn.Linear(1152, 1712)
        self.fc2 = nn.Linear(1712, 1024)
        self.fc3 = nn.Linear(1024, 756)
        self.fc4 = nn.Linear(756, CLASSES)
        # Kept for callers that want probabilities; the explicit dim avoids
        # the implicit-dimension deprecation warning of a bare nn.Softmax().
        self.softmax = nn.Softmax(dim=-1)
        # Separate metric objects so training and validation batches never
        # accumulate into the same internal state.
        self.accuracy = metrics.Accuracy()
        self.val_accuracy = metrics.Accuracy()

    def forward(self, x):
        """Return the raw class scores for a batch of images."""
        out = self.cnv(x)
        out = self.rel(out)
        out = self.bn(out)
        out = self.mxpool(out)
        out = self.flat(out)
        out = self.rel(self.fc1(out))
        out = self.rel(self.fc2(out))
        out = self.rel(self.fc3(out))
        return self.fc4(out)

    def loss_fn(self, out, target):
        """Cross-entropy between logits ``out`` and integer labels ``target``."""
        return nn.functional.cross_entropy(out.view(-1, CLASSES), target)

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters (lr = 1e-3)."""
        LR = 1e-3
        return torch.optim.AdamW(self.parameters(), lr=LR)

    def predict(self, x):
        """Return the predicted class index for every sample in ``x``.

        The module is switched to evaluation mode for the forward pass so
        that BatchNorm uses its running statistics (and does not update
        them); the previous training/eval mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                return torch.argmax(self(x), dim=1)
        finally:
            self.train(was_training)

    def _shared_step(self, batch, metric):
        """Compute (loss, batch accuracy) for one batch using ``metric``."""
        x, y = batch
        imgs = x.view(-1, 3, IMG_SIZE, IMG_SIZE)
        labels = y.view(-1)
        out = self(imgs)
        loss = self.loss_fn(out, labels)
        # argmax over logits equals argmax over softmax(logits), so no
        # softmax is needed to obtain the predicted class.
        preds = torch.argmax(out, dim=1)
        return loss, metric(preds, labels)

    def training_step(self, batch, batch_idx):
        """Run one training batch, log loss/accuracy, and return the loss."""
        loss, accu = self._shared_step(batch, self.accuracy)
        self.log('train_loss', loss, prog_bar=True)
        self.log('acc', accu, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Run one validation batch, log loss/accuracy, return both."""
        loss, accu = self._shared_step(batch, self.val_accuracy)
        self.log('valid_loss', loss, prog_bar=True)
        # Previously logged as 'train_acc_step', which mislabelled the
        # validation accuracy as a training metric.
        self.log('valid_acc', accu, prog_bar=True)
        return loss, accu


# Obtengo el dataset de imágenes
Además, aplico transformaciones para ayudar al entrenamiento de la CNN.

In [3]:
# Preprocessing for the training images: resize to the network input size,
# convert to a float tensor and normalise each channel with the given
# (mean, std) pairs.
# NOTE(review): the mean/std values are presumably statistics of this
# dataset - confirm where they were computed.
train_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((IMG_SIZE, IMG_SIZE)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.4740, 0.4676, 0.4134], [0.2143, 0.2085, 0.2333]),
])
dataset = torchvision.datasets.ImageFolder(root=TRAIN_PATH, transform=train_transform)

# Keep 48000 images for training and use whatever remains for validation.
# Deriving the validation size from the dataset avoids the ValueError that
# random_split raises when hard-coded lengths do not sum to len(dataset).
train_size = 48000
val_size = len(dataset) - train_size
if val_size <= 0:
    raise ValueError(
        f"Dataset has {len(dataset)} images; more than {train_size} are needed "
        "to build a validation split."
    )
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=512, shuffle=True, pin_memory=True, num_workers=2
)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=256, num_workers=2)

# Sanity check: fetch one training batch and display its tensor shapes.
batch = next(iter(train_loader))
imgs, labels = batch[0].to(device), batch[1].to(device)
imgs.shape, labels.shape

In [4]:
# Instantiate the network and move its parameters to the selected device.
mod = BirdsModel().to(device)

# Entrenando el modelo

In [5]:
# Choose the accelerator from the device that was actually detected.
# The previous configuration always requested accelerator='gpu' while
# passing gpus=0 on CPU-only machines, which contradicts itself.
trainer = pl.Trainer(
    accelerator="gpu" if device.type == "cuda" else "cpu",
    devices=1,
    max_epochs=10,
)
# Train for 10 epochs, running the validation loader alongside training.
trainer.fit(mod, train_loader, val_loader)

In [17]:
# Snapshot the trained model's state dict.
state_dict = mod.state_dict()

# torch.save(object, filename); the filename may use any extension.
torch.save(state_dict, "bird_cnn.pth")

# CARGAR MODELO Y PREDECIR


Preparación del submission.csv

In [20]:
# Preprocessing applied to every test image: resize to 64x64, convert to a
# tensor, then normalise each channel with the given mean and std.
test_transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((64, 64)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.4740, 0.4676, 0.4134],
            std=[0.2143, 0.2085, 0.2333],
        ),
    ]
)

# Custom dataset used for prediction; it wraps a list holding the path of
# every image in the test set.
# NOTE: this class rebinds the name ``Dataset`` imported from
# torch.utils.data at the top of the file.
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset that loads and transforms test images on demand.

    Args:
        list_IDs: Sequence of image file paths.
        transform: Optional callable applied to each PIL image. When omitted,
            the module-level ``test_transform`` is looked up at access time.
    """

    def __init__(self, list_IDs, transform=None):
        self.list_IDs = list_IDs
        self.transform = transform

    def __len__(self):
        """Return the number of image paths."""
        return len(self.list_IDs)

    def __getitem__(self, index):
        """Return the transformed image at ``index`` with a leading axis of size 1."""
        # Select sample
        path = self.list_IDs[index]
        transform = self.transform if self.transform is not None else test_transform

        # Load data. The context manager closes the file handle, and
        # convert("RGB") guarantees three channels so the 3-channel
        # normalisation also works for grayscale or RGBA files.
        with Image.open(path) as img:
            X = transform(img.convert("RGB")).unsqueeze(0)

        return X

In [21]:
# Start from three independent empty lists: class names, test file names
# and full paths of the test images.
classes, ids, test_image_paths = [], [], []

In [22]:
# Class names in sorted order, restricted to sub-directories. ImageFolder
# derives its label indices from the sorted sub-directory names, so stray
# files in TRAIN_PATH must not shift the index -> name mapping.
classes = sorted(
    entry
    for entry in os.listdir(TRAIN_PATH)
    if os.path.isdir(os.path.join(TRAIN_PATH, entry))
)

# Build the paths and the ids from ONE sorted listing. Previously the ids
# came from os.listdir and the paths from glob, whose order (and length,
# for hidden files) may differ, which would pair ids with the wrong
# predictions. Assigning instead of appending also keeps the cell
# idempotent when it is re-run.
test_image_paths = sorted(glob.glob(TEST_PATH + '/*'))
ids = [os.path.basename(path) for path in test_image_paths]

test_dataset = Dataset(test_image_paths)
print(classes)

In [23]:
# map_location="cpu" lets a checkpoint written from a GPU run be loaded on
# a machine without CUDA; the fresh model below lives on the CPU as well.
state_dict = torch.load("./bird_cnn.pth", map_location="cpu")

new_model = BirdsModel()
# Inference mode: BatchNorm must use its stored running statistics instead
# of the statistics of the batch being predicted.
new_model.eval()
new_model.load_state_dict(state_dict)


In [24]:
# Loader over the test images: batches of 2000, two worker processes,
# shuffle left at its default.
test_loader= torch.utils.data.DataLoader(test_dataset,batch_size=2000, num_workers=2)

In [25]:
# Pull only the first batch from the test loader.
imgs = batch = next(iter(test_loader))
# Each sample carries a singleton axis added by the dataset's unsqueeze(0);
# drop it so the batch is (N, C, H, W).
test_imgs = imgs.squeeze(1)
test_imgs.shape

In [26]:
# Predict in evaluation mode and without autograd bookkeeping: BatchNorm
# then uses its running statistics and no graph is kept in memory.
new_model.eval()
with torch.no_grad():
    outputs = new_model(test_imgs)
_, predicted = torch.max(outputs, 1)

# Map every predicted index to its class name.
predicted_indices = predicted.tolist()
pred_classes = [classes[index] for index in predicted_indices]

# Iterate over the actual predictions instead of a hard-coded range(2000),
# which raised IndexError whenever the batch held fewer images.
print('Predicted: ', ' '.join(f'{index:5d}' for index in predicted_indices))

In [28]:
# Submission ids are the file names without their extension. Only a
# trailing '.jpg' is removed (str.replace would also strip a '.jpg' found
# in the middle of a name), and the comprehension avoids rebinding the
# builtin name `id`.
ids_replace = [
    file_name[:-len('.jpg')] if file_name.endswith('.jpg') else file_name
    for file_name in ids
]
dic = {'Id': ids_replace, 'Category': pred_classes}
df = pd.DataFrame(data=dic)
df

In [29]:
# Write the predictions to submission.csv, omitting the DataFrame index column.
df.to_csv('submission.csv',index=False)