In [1]:
cd course_intro_ocr/task2

/home/silevichar/liza/course_intro_ocr/task2


In [2]:
from data_reader import Vocabulary, HWDBDatasetHelper, LMDBReader
import cv2
import numpy as np
from IPython.display import clear_output, Image

import torch
from pathlib import Path
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models

from tqdm import tqdm

In [3]:
class CustomDataset(Dataset):
    def __init__(self, dataset_helper):
        self.dataset_helper = dataset_helper

    def __len__(self):
        return self.dataset_helper.size()

    def __getitem__(self, index):
        image, label = self.dataset_helper.get_item(index)
        image = cv2.resize(image, (128, 128))
        image = (image - 127.5) / 255.0
        return image, label
    

train_data_path = 'train.lmdb'
test_data_path = 'test.lmdb'
ground_truth_path = 'gt.txt'

train_reader = LMDBReader(train_data_path)
train_reader.open()
train_dataset_helper = HWDBDatasetHelper(train_reader)

train_dataset_helper, val_dataset_helper = train_dataset_helper.train_val_split()
num_classes = train_dataset_helper.vocabulary.num_classes()

train_dataset = CustomDataset(train_dataset_helper)
val_dataset = CustomDataset(val_dataset_helper)

train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True, drop_last=True, num_workers=8)
val_loader = DataLoader(val_dataset, batch_size=512, shuffle=False, num_workers=8)

In [4]:
gpu_index = 2
device = torch.device(f'cuda:{gpu_index}')

class CenterLoss(nn.Module):
    def __init__(self, num_classes=num_classes, feat_dim=512):
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).to(device))

    def forward(self, x, labels):
        batch_size = x.size(0)
        dist_mat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                   torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        dist_mat.addmm_(1, -2, x, self.centers.t())
        labels_expand = labels.unsqueeze(1).expand(batch_size, self.num_classes)
        mask = labels_expand.eq(torch.arange(self.num_classes).long().to(device).expand(batch_size, self.num_classes))
        dist = dist_mat * mask.float()
        loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size
        return loss

class ResNetWithCenterLoss(nn.Module):
    def __init__(self, num_classes):
        super(ResNetWithCenterLoss, self).__init__()
        self.network = models.resnet18()
        self.network.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        num_features = self.network.fc.in_features
        self.network.fc = nn.Identity()
        self.head = nn.Linear(num_features, num_classes)

    def forward(self, images):
        features = self.network(images)
        output = self.head(features)
        return features, output


In [5]:
model = ResNetWithCenterLoss(num_classes).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_function = torch.nn.CrossEntropyLoss()
center_loss = CenterLoss()
center_loss_optimizer = torch.optim.SGD(center_loss.parameters(), lr=0.5)

accuracy_list = []
model.load_state_dict(torch.load('resnet_epoch_34.pth'))
model.eval()

for epoch in range(35, 40):
    print(f'Epoch {epoch}:')
    model.train()
    alpha = 0.5
    for step, (images, labels) in enumerate(tqdm(train_loader)):
        features, outputs = model(images.unsqueeze(1).float().to(device))

        classification_loss = loss_function(outputs, labels.to(device))
        center_loss_value = center_loss(features, labels.to(device))
        total_loss = classification_loss + alpha * center_loss_value

        optimizer.zero_grad()
        center_loss_optimizer.zero_grad()
        total_loss.backward()

        for param in center_loss.parameters():
            param.grad.data *= (1. / alpha)

        optimizer.step()
        center_loss_optimizer.step()
    model.eval()
    correct_predictions = 0
    total_samples = 0
    progress_bar = tqdm(val_loader)

    with torch.no_grad():
        for step, (images, labels) in enumerate(progress_bar):
            _, predictions = model(images.unsqueeze(1).float().to(device))
            predicted_classes = torch.argmax(predictions, dim=1).cpu().numpy()
            correct_predictions += (predicted_classes == labels.cpu().numpy()).sum()
            total_samples += labels.size(0)

    accuracy = correct_predictions / total_samples

    print(f'Epoch {epoch}:')
    print(f'Accuracy: {accuracy}')
    accuracy_list.append(accuracy)
    torch.save(model.state_dict(), f'resnet_epoch_{epoch}.pth')

torch.save(model.state_dict(), 'checkpoint2.pth')

Epoch 35:


	addmm_(Number beta, Number alpha, Tensor mat1, Tensor mat2)
Consider using one of the following signatures instead:
	addmm_(Tensor mat1, Tensor mat2, *, Number beta, Number alpha) (Triggered internally at ../torch/csrc/utils/python_arg_parser.cpp:1630.)
  dist_mat.addmm_(1, -2, x, self.centers.t())
100%|██████████| 5036/5036 [16:20<00:00,  5.14it/s]
100%|██████████| 1260/1260 [01:46<00:00, 11.85it/s]


Epoch 35:
Accuracy: 0.9433423982600305
Epoch 36:


100%|██████████| 5036/5036 [16:06<00:00,  5.21it/s]
100%|██████████| 1260/1260 [01:44<00:00, 12.09it/s]


Epoch 36:
Accuracy: 0.947810222941349
Epoch 37:


100%|██████████| 5036/5036 [16:16<00:00,  5.16it/s]
100%|██████████| 1260/1260 [01:43<00:00, 12.14it/s]


Epoch 37:
Accuracy: 0.9506336399274599
Epoch 38:


100%|██████████| 5036/5036 [16:18<00:00,  5.15it/s]
100%|██████████| 1260/1260 [01:42<00:00, 12.25it/s]


Epoch 38:
Accuracy: 0.9501310096508115
Epoch 39:


100%|██████████| 5036/5036 [16:18<00:00,  5.15it/s]
100%|██████████| 1260/1260 [01:46<00:00, 11.78it/s]

Epoch 39:
Accuracy: 0.9525712486173789





In [6]:
predictions_file_path = 'pred.txt'

test_reader = LMDBReader(test_data_path)
test_reader.open()
test_dataset_helper = HWDBDatasetHelper(test_reader, prefix='Test')

test_dataset = CustomDataset(test_dataset_helper)
test_loader = DataLoader(test_dataset, batch_size=512, shuffle=False, num_workers=8)

predictions = []
model.eval()
with torch.no_grad():
    for images, _ in tqdm(test_loader):
        _, logits = model(images.unsqueeze(1).float().to(device))
        predicted_classes = torch.argmax(logits, dim=1).cpu().numpy()
        predictions.extend(predicted_classes)

with open(predictions_file_path, 'w') as pred_file:
    for index, prediction in enumerate(predictions):
        name = test_dataset_helper.namelist[index]
        class_name = train_dataset_helper.vocabulary.class_by_index(prediction)
        print(name, class_name, file=pred_file)

base_path = Path().absolute().parent.parent
ground_truth = {}
with open(ground_truth_path) as gt_file:
    for line in gt_file:
        filename, class_name = line.strip().split()
        ground_truth[filename] = class_name

correct_predictions = 0
total_samples = len(ground_truth)

with open(predictions_file_path) as pred_file:
    for line in pred_file:
        filename, predicted_class = line.strip().split()
        if predicted_class == ground_truth.get(filename):
            correct_predictions += 1

score = correct_predictions / total_samples
print(f'Accuracy = {score:.4f}')

100%|██████████| 1517/1517 [02:10<00:00, 11.63it/s]


Accuracy = 0.9262
