In [15]:
from PIL import Image
import cv2
import os
import matplotlib.pyplot as plt
import random
import scipy.io
import pandas as pd
import torch
import torchvision
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
device = 'cpu' # alas, no GPU available

In [8]:

# Training annotations: map every image file name to its stored class id
# minus one.
mat = scipy.io.loadmat('./devkit/cars_train_annos.mat')
fname_to_class = {
    fname_cell[0]: class_cell[0][0] - 1
    for fname_cell, class_cell in zip(mat['annotations'][0]['fname'],
                                      mat['annotations'][0]['class'])
}

# Class metadata: map the position of each entry in `class_names` to the
# entry's first element (presumably the human-readable class name).
cars_meta = scipy.io.loadmat('./devkit/cars_meta.mat')
id_to_car = dict(enumerate(name_cell[0] for name_cell in cars_meta['class_names'][0]))

In [5]:
# Directory that image file names are joined onto when images are read.
ADD_PATH = './cars_train'

In [9]:
# Deterministic preprocessing pipeline (no random augmentation).
val_transforms = torchvision.transforms.Compose([
    # The pipeline is fed array images, so go through PIL first.
    torchvision.transforms.ToPILImage(),
    # Fixed 224x224 input size.
    torchvision.transforms.Resize(size=(224, 224)),
    torchvision.transforms.ToTensor(),
    # Per-channel normalisation; these are the statistics conventionally
    # used with ImageNet-pretrained torchvision models.
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

class CropClassifDataset(torch.utils.data.Dataset):
    """Map-style dataset that yields car images together with their class ids.

    Args:
        cars_items: Indexable sequence of ``(filename, class_id)`` pairs. File
            names are resolved relative to the module-level ``ADD_PATH``.
        transforms: Optional callable applied to the RGB image array; a falsy
            value disables it.
    """

    def __init__(self, cars_items, transforms):
        self.cars = cars_items
        self.transforms = transforms

    def __len__(self):
        """Return the number of ``(filename, class_id)`` pairs."""
        return len(self.cars)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            dict with keys ``'image'`` (the transformed image, or the raw RGB
            array when no transform is set) and ``'label'`` (the class id
            exactly as stored in ``cars_items``).

        Raises:
            FileNotFoundError: If the image is missing or cannot be decoded.
        """
        filename, cl_id = self.cars[idx]
        path = os.path.join(ADD_PATH, filename)
        image = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside cvtColor.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        # OpenCV decodes to BGR channel order; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transforms:
            image = self.transforms(image)
        return {'image': image, 'label': cl_id}

In [10]:
# Reproducible 80/20 train/validation split.
# A dedicated, seeded RNG makes the split identical after every kernel
# restart (so validation images never drift into the training set between
# runs) without touching the global `random` state.
SPLIT_SEED = 42
items = list(fname_to_class.items())
random.Random(SPLIT_SEED).shuffle(items)

split_idx = int(len(items) * 0.8)
train_items = items[:split_idx]
val_items = items[split_idx:]

# NOTE(review): the training set uses the same deterministic pipeline as
# validation (no augmentation) -- confirm this is intended.
train_dataset = CropClassifDataset(train_items, val_transforms)
val_dataset = CropClassifDataset(val_items, val_transforms)

In [11]:
# Pinned (page-locked) host memory only helps host-to-GPU copies; requesting it
# on a CPU-only machine just triggers a warning, so enable it only with CUDA.
PIN_MEMORY = torch.cuda.is_available()

# Training batches are reshuffled every epoch and the last incomplete batch is
# dropped; validation keeps a fixed order and every sample.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4, pin_memory=PIN_MEMORY, drop_last=True)
valid_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=4, pin_memory=PIN_MEMORY)

In [13]:
from timm.scheduler import TanhLRScheduler
import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from timm.scheduler import TanhLRScheduler
from lightning.pytorch.callbacks import LearningRateMonitor
from torchvision.models import resnet50

class CarClassifier(pl.LightningModule):
    """ResNet-50 classifier whose final layer is replaced by a two-layer head.

    Args:
        class_dict: Mapping from class index to class label. Its length sets
            the number of output logits, and ``forward`` uses it to translate
            predicted indices into labels.
        learning_rate: Learning rate passed to the Adam optimizer.
        emb_size: Width of the hidden layer in the replacement head.
    """

    def __init__(self, class_dict, learning_rate, emb_size = 512):
        super().__init__()
        self.learning_rate = learning_rate
        self.class_dict = class_dict

        # Build the backbone exactly once (a randomly initialised copy used to
        # be created first and immediately discarded).
        # NOTE(review): newer torchvision releases prefer `weights=` over
        # `pretrained=`; switch once the minimum supported version allows it.
        self.model = resnet50(pretrained=True)

        # Read the head's input width from the backbone instead of
        # hard-coding it.
        backbone_features = self.model.fc.in_features
        self.model.fc = torch.nn.Sequential(
            torch.nn.Linear(in_features=backbone_features, out_features=emb_size),
            torch.nn.ReLU(inplace=False),
            torch.nn.Linear(in_features=emb_size, out_features=len(class_dict)),
        )

        self.criterion = torch.nn.CrossEntropyLoss()
        self.save_hyperparameters()

    def _shared_step(self, batch):
        """Run the network on one batch and return ``(logits, labels, loss)``."""
        labels = batch['label'].to(torch.long)
        logits = self.model(batch['image'])
        return logits, labels, self.criterion(logits, labels)

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one training batch."""
        _, _, loss = self._shared_step(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log loss and top-1 accuracy for one validation batch."""
        logits, labels, loss = self._shared_step(batch)
        # Fraction of samples whose arg-max logit matches the label.
        accuracy = (logits.argmax(dim=1) == labels).float().mean()
        self.log("validation_loss", loss, sync_dist=True)
        self.log("validation_accuracy", accuracy, sync_dist=True)

    def forward(self, images):
        """Predict class labels for a batch or a single image.

        Args:
            images: A 4-D tensor is passed to the network as is; any other
                rank gets a leading batch dimension added first.

        Returns:
            List with one ``class_dict`` entry per image.
        """
        if len(images.shape) != 4:
            images = images.unsqueeze(0)
        logits = self.model(images)
        return [self.class_dict[row.argmax().item()] for row in logits]

    def configure_optimizers(self):
        """Return a single Adam optimizer over all parameters, no scheduler."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return [optimizer]

In [23]:
# Model under training: class-label lookup table plus the learning rate.
pl_model = CarClassifier(id_to_car, 3e-4)

# Keep the three checkpoints with the highest validation accuracy.
checkpoint_callback = ModelCheckpoint(
    monitor='validation_accuracy',
    mode='max',
    save_top_k=3,
)

# Stop early on stalled validation loss (patience of 2); log the learning
# rate at every step.
early_stopping = EarlyStopping(
    monitor="validation_loss",
    mode="min",
    patience=2,
)
lr_monitor = LearningRateMonitor(logging_interval='step')

# Trainer configuration: single CPU device, at most 20 epochs.
trainer = pl.Trainer(
    max_epochs=20,
    accelerator='cpu',
    devices=1,
    strategy='auto',
    callbacks=[checkpoint_callback, early_stopping, lr_monitor],
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
