<a href="https://colab.research.google.com/github/JunHo06/HAI-Assignment/blob/main/Hai_summer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import glob

import PIL
import PIL.Image
import torch.utils.data as data
from torchvision.transforms import CenterCrop, Compose, InterpolationMode, Resize, ToTensor

def getCropImg(crop_size):
    """Build the high-resolution target transform.

    The returned pipeline center-crops a PIL image to a square of side
    ``crop_size`` and converts the crop to a tensor with ``ToTensor``.

    Args:
        crop_size: side length, in pixels, of the square center crop.

    Returns:
        A ``Compose`` pipeline of the two steps above.
    """
    steps = [CenterCrop(crop_size), ToTensor()]
    return Compose(steps)

def getLRimage(crop_size, upscale_factor):
    """Build the degraded (low-resolution) input transform.

    The returned pipeline does four things in order:

    1. Center-crops a PIL image to a square of side ``crop_size``.
    2. Shrinks it to ``crop_size // upscale_factor`` with bicubic
       interpolation.
    3. Enlarges it back to ``crop_size``, also bicubic.
    4. Converts it to a tensor.

    The output has the same spatial size as the ``getCropImg`` target, but
    fine detail lost in the down/up round trip is missing.

    Args:
        crop_size: side length, in pixels, of the square center crop.
        upscale_factor: factor by which the crop is shrunk before being
            enlarged again.

    Returns:
        A ``Compose`` pipeline of the steps above.

    Raises:
        ValueError: if the intermediate size ``crop_size // upscale_factor``
            is smaller than one pixel.
    """
    low_res_size = crop_size // upscale_factor
    if low_res_size < 1:
        raise ValueError(
            f"crop_size // upscale_factor must be >= 1, got "
            f"{crop_size} // {upscale_factor} = {low_res_size}"
        )
    # Use torchvision's InterpolationMode rather than a PIL integer constant.
    # Integer interpolation arguments to Resize are deprecated, and the enum
    # does not rely on PIL.Image having been imported as a side effect.
    return Compose([
        CenterCrop(crop_size),
        Resize(low_res_size, interpolation=InterpolationMode.BICUBIC),
        Resize(crop_size, interpolation=InterpolationMode.BICUBIC),
        ToTensor(),
    ])

class MyDataset(data.Dataset):
    """Super-resolution dataset over a slice of the PNG files in a directory.

    Files matching ``<image_dir>/*.png`` are sorted by path and sliced
    ``[start:end]``, so different (start, end) ranges give disjoint splits.
    Item ``i`` is an ``(input, target)`` pair of tensors built from the same
    image:

    - ``target`` is its center crop (``getCropImg``).
    - ``input`` is that crop degraded by bicubic down/up-scaling
      (``getLRimage``).

    Args:
        image_dir: directory containing the ``.png`` images.
        start: first index (inclusive) into the sorted file list.
        end: last index (exclusive) into the sorted file list.
        crop_size: side length of the square center crop (default 64).
        upscale_factor: degradation factor for the input image (default 2).
    """

    def __init__(self, image_dir, start, end, crop_size=64, upscale_factor=2):
        super().__init__()
        self.png_files = sorted(glob.glob(image_dir + "/*.png"))[start:end]
        self.input_transform = getLRimage(crop_size, upscale_factor)
        self.target_transform = getCropImg(crop_size)

    def __getitem__(self, index):
        """Return the ``(input, target)`` tensor pair for file ``index``."""
        # The context manager releases the file handle promptly. convert()
        # returns a fully loaded, independent image, so it outlives the
        # `with` block. Forcing RGB guarantees 3 channels: SRCNN's first conv
        # has in_channels=3, and PNGs may be grayscale, palette or RGBA.
        with PIL.Image.open(self.png_files[index]) as source:
            rgb = source.convert("RGB")

        # The transforms return new images and do not modify their argument,
        # so both pipelines can safely start from the same `rgb` image.
        input_img = rgb
        target_img = rgb
        if self.input_transform is not None:
            input_img = self.input_transform(input_img)
        if self.target_transform is not None:
            target_img = self.target_transform(target_img)

        return input_img, target_img

    def __len__(self):
        """Return the number of PNG files in this split."""
        return len(self.png_files)

In [None]:
import torch.nn as nn

class SRCNN(nn.Module):
    """Three-layer convolutional super-resolution network (9-1-5 SRCNN layout).

    Every conv uses stride 1 with padding (kernel_size - 1) // 2, so the
    output has the same height and width as the input. The network expects
    3-channel input and produces 3-channel output.
    """

    def __init__(self):
        super().__init__()
        # Patch extraction: 9x9 conv, 3 -> 64 channels.
        self.layer1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=9, stride=1, padding=4)
        # Non-linear mapping: 1x1 conv, 64 -> 32 channels.
        self.layer2 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=1, stride=1, padding=0)
        # Reconstruction: 5x5 conv, 32 -> 3 channels.
        self.layer3 = nn.Conv2d(in_channels=32, out_channels=3, kernel_size=5, stride=1, padding=2)
        # Parameter-free, so applying it does not change the state_dict keys.
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map a 3-channel image batch to a same-sized 3-channel reconstruction."""
        # The ReLUs after the first two layers make the model non-linear.
        # Without them the three convs compose into a single affine conv.
        x = self.relu(self.layer1(x))
        x = self.relu(self.layer2(x))
        return self.layer3(x)

In [None]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SRCNN().to(device)
epochs = 50
batch_size = 32
lr = 0.001
loss_fn = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

# Consecutive, non-overlapping slices of the sorted PNG list:
# [0, 363) train, [363, 467) validation, [467, 519) test.
trainDataNum, validDataNum, testDataNum = 363, 104, 52
data_dir = '/content/drive/MyDrive/DIV2K_HR_519sampled'
train_end = trainDataNum
valid_end = train_end + validDataNum
test_end = valid_end + testDataNum

train_set = MyDataset(data_dir, 0, train_end)
valid_set = MyDataset(data_dir, train_end, valid_end)
test_set = MyDataset(data_dir, valid_end, test_end)

training_data_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)
valid_data_loader = DataLoader(dataset=valid_set, batch_size=batch_size, shuffle=False)
test_data_loader = DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False)

for epoch in range(epochs):
    # Re-enter training mode every epoch. The validation pass below switches
    # the model to eval mode, and nothing else switches it back.
    model.train()
    for inputs, targets in training_data_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        output = model(inputs)
        loss = loss_fn(output, targets)
        loss.backward()
        optimizer.step()

    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for inputs, targets in valid_data_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            output = model(inputs)
            # loss_fn returns the mean over one batch. Weight it by batch size
            # so a short final batch (e.g. 8 of 104 images with batch_size=32)
            # is not over-weighted in the epoch average.
            val_loss += loss_fn(output, targets).item() * inputs.size(0)

    # Per-sample mean. max(..., 1) avoids division by zero on an empty split.
    avg_val_loss = val_loss / max(len(valid_set), 1)
    print(f"Epoch [{epoch+1}/{epochs}], Validation Loss: {avg_val_loss:.4f}")

In [None]:
import torch

# Save only the learned parameters (state_dict), not the pickled module object.
checkpoint_path = "srcnn_checkpoint.pth"
torch.save(model.state_dict(), checkpoint_path)

# Rebuild the architecture and restore the weights.
# - map_location puts the tensors on the current device even if the
#   checkpoint was written on a different one (e.g. saved on GPU, reloaded
#   on a CPU-only runtime).
# - weights_only=True restricts unpickling to tensors and plain containers.
new_model = SRCNN().to(device)
new_model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))
new_model.eval()