In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torchvision import models
from tqdm import tqdm
from matplotlib import pyplot as plt

In [2]:
class DigitSumDataset(Dataset):
    """In-memory dataset pairing image arrays with scalar labels.

    Both ``.npy`` files are read eagerly in the constructor. Images are
    converted to float32 and divided by 255; labels are converted to
    float32. No channel axis is added here, so a consumer that needs one
    has to insert it itself.
    """

    def __init__(self, image_path, label_path):
        """Load the image and label arrays stored at the given ``.npy`` paths."""
        raw_images = np.load(image_path)
        raw_labels = np.load(label_path)

        # Dividing by 255 presumably maps 8-bit pixels into [0, 1] -- verify
        # against the actual dtype/range of the stored arrays.
        self.images = torch.tensor(raw_images, dtype=torch.float32).div(255.0)
        self.labels = torch.tensor(raw_labels, dtype=torch.float32)

    def __len__(self):
        """Return the number of samples (length of the label tensor)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``."""
        image, label = self.images[idx], self.labels[idx]
        return image, label

In [3]:
class ResNetModel(nn.Module):
    """ResNet-50 regressor for single-channel images.

    The backbone is created with the ``IMAGENET1K_V1`` weights. Its stem
    convolution is replaced by a 1-input-channel convolution, and its
    classification layer by a 128-unit ReLU projection. A final linear
    layer maps those 128 features to one scalar per sample.

    Input is expected as ``(batch, 1, H, W)``; output is ``(batch, 1)``.
    """

    def __init__(self):
        super().__init__()
        self.resnet = models.resnet50(weights='IMAGENET1K_V1')

        # Swap the 3-channel stem for a 1-channel one. Instead of leaving the
        # new layer randomly initialised (which would discard the pretrained
        # first-layer filters), seed it with the pretrained kernels summed
        # over the RGB axis. That is numerically the same as feeding the
        # grayscale image replicated on all three channels to the original
        # stem. The layer's shape is unchanged, so existing checkpoints of
        # this model still load.
        pretrained_stem = self.resnet.conv1
        self.resnet.conv1 = nn.Conv2d(
            1, 64, kernel_size=7, stride=2, padding=3, bias=False
        )
        with torch.no_grad():
            self.resnet.conv1.weight.copy_(
                pretrained_stem.weight.sum(dim=1, keepdim=True)
            )

        # Replace the 1000-way classifier with a compact feature projection.
        num_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Sequential(
            nn.Linear(num_features, 128),
            nn.ReLU()
        )

        # Regression head: a single scalar prediction per sample.
        self.fc = nn.Sequential(
            nn.Linear(128, 1)
        )

    def forward(self, x):
        """Run the backbone and the regression head; returns ``(batch, 1)``."""
        x = self.resnet(x)
        x = self.fc(x)
        return x

In [4]:
# One (image file, label file) pair per split: files suffixed 0 feed
# training, 1 validation and 2 testing.
train_img_path, train_label_path = 'DL-Project/data0.npy', 'DL-Project/lab0.npy'
val_img_path, val_label_path = 'DL-Project/data1.npy', 'DL-Project/lab1.npy'
test_img_path, test_label_path = 'DL-Project/data2.npy', 'DL-Project/lab2.npy'

# Each dataset reads its arrays from disk when it is constructed.
train_dataset = DigitSumDataset(image_path=train_img_path, label_path=train_label_path)
val_dataset = DigitSumDataset(image_path=val_img_path, label_path=val_label_path)
test_dataset = DigitSumDataset(image_path=test_img_path, label_path=test_label_path)

In [5]:
# Batches of 32 for both splits; only the training split is shuffled.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=32,
    shuffle=False,
)

In [6]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Train the regressor with MSE loss for 10 epochs. After every epoch the
# mean training and validation losses are recorded, and the weights are
# checkpointed whenever the validation loss is the best seen so far.
img_model = ResNetModel().to(device)
criterion = nn.MSELoss().to(device)
optimizer = optim.Adam(img_model.parameters(), lr=0.001)

train_losses = []
val_losses = []

for epoch in range(10):
    # ---- training pass ----
    img_model.train()
    train_loss = 0.0
    train_seen = 0
    for images, labels in tqdm(train_loader):
        # The dataset yields images without a channel axis; add one at dim 1.
        images = images.unsqueeze(1).to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = img_model(images)
        # squeeze(1) removes only the size-1 feature axis. A bare squeeze()
        # would also drop the batch axis for a batch of one and make the
        # loss broadcast. Assumes labels are 1-D per batch -- confirm.
        loss = criterion(outputs.squeeze(1), labels)
        loss.backward()
        optimizer.step()

        # Weight each batch mean by its size so the epoch figure is a true
        # per-sample mean even when the final batch is short.
        batch_size = labels.size(0)
        train_loss += loss.item() * batch_size
        train_seen += batch_size
    # Record the epoch mean (not just the last batch) so the training curve
    # is comparable with the validation curve.
    train_losses.append(train_loss / max(train_seen, 1))

    # ---- validation pass ----
    img_model.eval()
    val_loss = 0.0
    val_seen = 0
    with torch.no_grad():
        for images, labels in tqdm(val_loader):
            images = images.unsqueeze(1).to(device)
            labels = labels.to(device)

            outputs = img_model(images)
            loss = criterion(outputs.squeeze(1), labels)

            batch_size = labels.size(0)
            val_loss += loss.item() * batch_size
            val_seen += batch_size
    val_losses.append(val_loss / max(val_seen, 1))

    # save best model
    if val_losses[-1] == min(val_losses):
        torch.save(img_model.state_dict(), 'best_model.pth')

    print(f'Epoch {epoch} val loss: {val_losses[-1]}')

 40%|███▉      | 125/313 [02:57<03:28,  1.11s/it]

In [None]:
# Per-epoch loss curves for the training and validation splits. Axis labels
# and a title are set so the figure can be read on its own.
fig, ax = plt.subplots()
ax.plot(train_losses, label='train')
ax.plot(val_losses, label='val')
ax.set_xlabel('epoch')
ax.set_ylabel('MSE loss')
ax.set_title('Training vs. validation loss')
ax.legend()
plt.show()