In [12]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
import numpy as np
import random
import os
import cv2
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset, random_split
from tqdm import tqdm


In [7]:
def set_seed(seed: int) -> None:
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and all CUDA devices), then puts cuDNN into deterministic mode with
    benchmark auto-tuning turned off.

    Args:
        seed: Value handed to each generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    # Only processes started after this point pick the variable up; the hash
    # seed of the interpreter that is already running is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    # manual_seed_all covers every GPU, not only the currently selected one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(42)


In [8]:
# Run on the GPU when CUDA is usable; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

Using device: cuda


In [9]:
# Three (images, labels) shards stored as .npy files under ../data.
# Each image array is read together with the label array of the same index.
data1, lab1 = np.load('../data/data0.npy'), np.load('../data/lab0.npy')
data2, lab2 = np.load('../data/data1.npy'), np.load('../data/lab1.npy')
data3, lab3 = np.load('../data/data2.npy'), np.load('../data/lab2.npy')

In [10]:
# Join the three shards along the first axis into one evaluation set.
# np.concatenate is used because the np.concat alias only exists in
# NumPy >= 2.0; np.concatenate works on every NumPy version.
data_cumulative = np.concatenate([data1, data2, data3])
labels_cumulative = np.concatenate([lab1, lab2, lab3])
labels = torch.tensor(labels_cumulative, dtype=torch.long)

In [13]:
def process_images(data_cumulative, min_area=20, border=10, digit_size=28, out_size=224):
    """Segment each image into connected components and rebuild a normalised strip.

    For every image the function:
      1. binarises it with Otsu's threshold and dilates it with a
         2-row x 1-column kernel,
      2. finds 8-connected components and keeps those whose pixel area is
         greater than ``min_area``,
      3. orders the kept components left to right by their bounding-box x,
      4. crops each one, pads it with ``border`` zero pixels on every side and
         resizes it to ``digit_size`` x ``digit_size``,
      5. concatenates the crops horizontally and resizes the result to
         ``out_size`` x ``out_size`` (an all-zero image is used when no
         component survives the area filter).

    Args:
        data_cumulative: Iterable of single images.
            NOTE(review): Otsu thresholding presumably requires 8-bit
            single-channel input - confirm the dtype stored in the .npy files.
        min_area: Components with an area of at most this many pixels are dropped.
        border: Zero padding, in pixels, added around each crop.
        digit_size: Side length each padded crop is resized to.
        out_size: Side length of the final square image.

    Returns:
        A list with one ``float32`` tensor of shape ``(1, out_size, out_size)``
        per input image, with pixel values divided by 255.
    """
    # Loop-invariant structuring element, built once instead of once per image.
    kernel = np.ones((2, 1), np.uint8)

    processed_images = []
    for img in data_cumulative:
        _, binary = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        binary = cv2.dilate(binary, kernel, iterations=1)
        num_labels, _, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)

        # Index 0 is skipped: OpenCV reports the background as label 0.
        digit_regions = [
            (
                stats[i, cv2.CC_STAT_LEFT],
                stats[i, cv2.CC_STAT_TOP],
                stats[i, cv2.CC_STAT_WIDTH],
                stats[i, cv2.CC_STAT_HEIGHT],
            )
            for i in range(1, num_labels)
            if stats[i, cv2.CC_STAT_AREA] > min_area
        ]
        # Left-to-right order; the sort is stable, so ties keep label order.
        digit_regions.sort(key=lambda region: region[0])

        digit_images = []
        for x, y, w, h in digit_regions:
            digit = binary[y:y + h, x:x + w]
            padded = cv2.copyMakeBorder(digit, border, border, border, border,
                                        cv2.BORDER_CONSTANT, value=0)
            digit_images.append(cv2.resize(padded, (digit_size, digit_size)))

        if digit_images:
            strip = np.concatenate(digit_images, axis=1)
            strip = cv2.resize(strip, (out_size, out_size))
        else:
            # Nothing passed the area filter: fall back to a blank image.
            strip = np.zeros((out_size, out_size), dtype=np.uint8)

        # Scale to [0, 1] and add a leading channel dimension.
        image_tensor = torch.tensor(strip / 255.0, dtype=torch.float32).unsqueeze(0)
        processed_images.append(image_tensor)

    return processed_images


processed_images = process_images(data_cumulative)

In [16]:
class ResNetForClassification(nn.Module):
    """ResNet-18 whose final fully connected layer is swapped for a small MLP head.

    The head maps the backbone features to 128 units (ReLU, dropout 0.5) and
    then to ``num_classes`` outputs.
    """

    def __init__(self, num_classes):
        super().__init__()
        backbone = models.resnet18(weights='DEFAULT')
        # Layer order is kept as-is so parameter names in saved checkpoints match.
        head_layers = [
            nn.Linear(backbone.fc.in_features, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes),
        ]
        backbone.fc = nn.Sequential(*head_layers)
        self.resnet = backbone

    def forward(self, x):
        """Run ``x`` through the modified ResNet and return its output."""
        logits = self.resnet(x)
        return logits

class SumDataset(Dataset):
    """Index-aligned pairs of images and integer labels.

    Args:
        images: Indexable collection of images.
        labels: Indexable collection of labels, aligned with ``images``.
        transform: Optional callable applied to each image when it is fetched.
    """

    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``; the label is a ``torch.long`` tensor."""
        img = self.images[idx]
        label = self.labels[idx]
        # Compare with None explicitly so a callable that happens to be falsy
        # (e.g. an empty container module) is still applied.
        if self.transform is not None:
            img = self.transform(img)
        if isinstance(label, torch.Tensor):
            # torch.tensor(tensor) triggers PyTorch's copy-construct warning;
            # clone/detach is the recommended way to copy an existing tensor.
            label = label.detach().clone().to(dtype=torch.long)
        else:
            label = torch.tensor(label, dtype=torch.long)
        return img, label

# Per-sample preprocessing: tensor -> PIL image -> 224x224 -> 3-channel RGB ->
# tensor -> channel-wise normalisation.
# NOTE(review): the mean/std values are presumably the ImageNet statistics that
# the pretrained ResNet weights expect - confirm against the training notebook.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.Lambda(lambda img: img.convert("RGB")),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

images = processed_images
labels = labels_cumulative
dataset = SumDataset(images, labels, transform=transform)
# Evaluation only: keep the sample order fixed instead of reshuffling on every
# pass, so batches are reproducible from run to run.
test_loader = DataLoader(dataset, batch_size=64, shuffle=False)

In [15]:
# SECURITY: this unpickles a whole module object, which can execute arbitrary
# code - only load checkpoints from a trusted source.
# weights_only=False is spelled out because the file is used directly as a model
# below, so it presumably holds a pickled nn.Module rather than a state_dict.
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
model = torch.load("model.pth", map_location=device, weights_only=False)
# Make sure the parameters live on the same device the batches are moved to.
model = model.to(device)

  model = torch.load("model.pth")


In [17]:
# Top-1 accuracy over everything test_loader yields.
model.eval()
correct = 0
total = 0
with torch.no_grad():  # inference only, no gradients needed
    # The batch labels get their own name so the notebook-level `labels`
    # array is not overwritten by the last batch.
    for imgs, targets in test_loader:
        imgs, targets = imgs.to(device), targets.to(device)
        outputs = model(imgs)
        predicted = outputs.argmax(dim=1)
        correct += (predicted == targets).sum().item()
        total += targets.size(0)
# An empty loader would otherwise raise ZeroDivisionError.
accuracy = 100 * correct / total if total else float('nan')
print(f"Test Accuracy: {accuracy:.2f}% ")


Test Accuracy: 90.36% 
