In [None]:
# Install/upgrade the PyTorch stack quietly (-q). Versions are unpinned, so a rerun may pull newer releases.
# NOTE(review): torchaudio is not imported in any cell shown here; consider dropping it. `%pip` is
# preferable to `!pip` because it installs into the running kernel's environment.
!pip install torch torchvision torchaudio -q

In [None]:
# Mount Google Drive so the dataset, the checkpoint and the SR outputs stored
# under /content/drive/MyDrive are reachable from this Colab runtime.
from google.colab import drive

drive.mount(mountpoint="/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader,random_split,Dataset
from PIL import Image
import pandas as pd
import os,glob
import torchvision.transforms.functional as F

# Dataset 정의

In [None]:
# Data configuration: where the images live and how (LR, HR) training pairs are made.
img_dir = "/content/drive/MyDrive/DIV2K_519sampled"

crop_size = 64        # side length of the square HR patch cut from the image centre
upscale_factor = 2    # the LR input is produced by downscaling the patch by this factor

# HR target: the central crop_size x crop_size patch, converted to a float tensor.
hrtransform = transforms.Compose([
    transforms.CenterCrop(crop_size),
    transforms.ToTensor()
])

# LR input: the same central patch, bicubic-downscaled by `upscale_factor` and then
# bicubic-upscaled back to crop_size, so input and target share one spatial size.
# The `interpolation` argument is documented as an `InterpolationMode`; PIL's integer
# filter constants are only accepted as a legacy/compatibility form.
lrtransform = transforms.Compose([
    transforms.CenterCrop(crop_size),
    transforms.Resize(crop_size // upscale_factor, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.Resize(crop_size, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor()
])

# Sizes of the contiguous train / valid / test slices of the sorted file list
# (363 + 104 + 52 = 519, presumably the full DIV2K_519sampled folder — confirm).
train_size = 363
valid_size = 104
test_size = 52

class mydataset(Dataset):
  """(low-res, high-res) tensor pairs built from a slice of the PNGs in a folder.

  The ``*.png`` paths in ``img_dir`` are sorted and sliced ``[start:end]``, so
  callers can carve disjoint, deterministic splits out of the same directory.
  Item ``i`` is ``(lr, hr)``: the LR and HR transforms applied to the same
  RGB-converted image.

  Args:
    img_dir: directory containing the PNG images.
    start, end: slice bounds into the sorted list of PNG paths.
    hr_transform: transform producing the target; ``None`` uses the
      notebook-level ``hrtransform``.
    lr_transform: transform producing the degraded input; ``None`` uses the
      notebook-level ``lrtransform``.
  """

  def __init__(self, img_dir, start, end, hr_transform=None, lr_transform=None):
    super().__init__()
    self.img_dir = img_dir
    # Sorting makes the [start:end] split independent of filesystem listing order.
    self.image = sorted(glob.glob(os.path.join(img_dir, "*.png")))[start:end]
    # Fall back to the notebook-level transforms so existing 3-argument calls keep working.
    self.hrtransforms = hrtransform if hr_transform is None else hr_transform
    self.lrtransforms = lrtransform if lr_transform is None else lr_transform

  def __getitem__(self, index):
    # `with` closes the file handle. convert("RGB") returns a fully loaded copy and
    # forces 3 channels: grayscale, palette or RGBA PNGs would otherwise produce
    # tensors whose channel count does not match the model's input.
    with Image.open(self.image[index]) as img:
      rgb = img.convert("RGB")
    return self.lrtransforms(rgb), self.hrtransforms(rgb)

  def __len__(self):
    return len(self.image)

# Consecutive, non-overlapping slices of the sorted file list: [train | valid | test].
train_dataset, valid_dataset, test_dataset = (
    mydataset(img_dir, lo, hi)
    for lo, hi in (
        (0, train_size),
        (train_size, train_size + valid_size),
        (train_size + valid_size, train_size + valid_size + test_size),
    )
)

# Only the training split is shuffled; validation/test order stays fixed.
train_loader, valid_loader, test_loader = (
    DataLoader(split, batch_size=32, shuffle=is_train)
    for split, is_train in (
        (train_dataset, True),
        (valid_dataset, False),
        (test_dataset, False),
    )
)


# Model 정의

In [None]:
class srcnn(nn.Module):
  """Three-layer SRCNN: patch extraction -> non-linear mapping -> reconstruction.

  Every convolution is padded to keep height and width unchanged, so the model
  expects an input that has already been upscaled to the target size (here the
  bicubic down/up-sampled LR patches) and predicts the HR patch at that size.

  Args:
    num_channels: channels of the input and output images. Defaults to 3 (RGB),
      which matches the original layer shapes and checkpoint keys; use 1 for a
      single-channel (e.g. luminance-only) variant.
  """

  def __init__(self, num_channels=3):
    super().__init__()
    # 9x9 patch extraction into 64 feature maps; padding 4 preserves H and W.
    self.layer1 = nn.Sequential(
        nn.Conv2d(num_channels, 64, kernel_size=9, stride=1, padding=4),
        nn.ReLU(inplace=True),
    )
    # 1x1 non-linear mapping from 64 to 32 feature maps.
    self.layer2 = nn.Sequential(
        nn.Conv2d(64, 32, kernel_size=1, stride=1, padding=0),
        nn.ReLU(inplace=True),
    )
    # 5x5 reconstruction back to image channels. There is no final activation, so
    # outputs are unbounded and must be clamped before being saved as images.
    self.layer3 = nn.Sequential(
        nn.Conv2d(32, num_channels, kernel_size=5, stride=1, padding=2),
    )

  def forward(self, x):
    """Map an upscaled image batch to a same-sized super-resolved estimate."""
    return self.layer3(self.layer2(self.layer1(x)))

# Model 학습

In [None]:
# Train on the GPU when one is available; every tensor below is moved to `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Seed the global torch RNG so weight initialisation and DataLoader shuffling are reproducible.
torch.manual_seed(0)

model = srcnn().to(device)

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr = 0.001)
epochs = 50


def _run_epoch(net, loader, loss_fn, dev, opt=None):
  """Run one full pass of `net` over `loader` and return the average loss.

  With `opt` the pass trains (train mode, gradients, optimizer steps); without
  it the pass evaluates (eval mode, no gradient tracking). Each batch loss is
  weighted by its batch size, so a short final batch is not over-counted.
  Returns NaN for an empty loader instead of dividing by zero.
  """
  training = opt is not None
  net.train(training)
  loss_sum = 0.0
  n_samples = 0
  with torch.set_grad_enabled(training):
    for lr_batch, hr_batch in loader:
      lr_batch = lr_batch.to(dev)
      hr_batch = hr_batch.to(dev)
      loss = loss_fn(net(lr_batch), hr_batch)
      if training:
        opt.zero_grad()
        loss.backward()
        opt.step()
      batch_n = lr_batch.size(0)
      loss_sum += loss.item() * batch_n
      n_samples += batch_n
  return loss_sum / n_samples if n_samples else float("nan")


for epoch in range(1, epochs + 1):
  train_loss = _run_epoch(model, train_loader, criterion, device, optimizer)
  val_loss = _run_epoch(model, valid_loader, criterion, device)
  # A single 1-based epoch counter for both losses (previously train was 0-based, validation 1-based).
  print(f"Epoch [{epoch}/{epochs}] train loss: {train_loss:.4f}, validation loss: {val_loss:.4f}")

Using device: cuda
Epoch 0, Loss: 0.0729
Epoch [1/50], Validation Loss: 0.0329
Epoch 1, Loss: 0.0263
Epoch [2/50], Validation Loss: 0.0197
Epoch 2, Loss: 0.0181
Epoch [3/50], Validation Loss: 0.0156
Epoch 3, Loss: 0.0131
Epoch [4/50], Validation Loss: 0.0109
Epoch 4, Loss: 0.0098
Epoch [5/50], Validation Loss: 0.0087
Epoch 5, Loss: 0.0079
Epoch [6/50], Validation Loss: 0.0076
Epoch 6, Loss: 0.0072
Epoch [7/50], Validation Loss: 0.0068
Epoch 7, Loss: 0.0063
Epoch [8/50], Validation Loss: 0.0060
Epoch 8, Loss: 0.0058
Epoch [9/50], Validation Loss: 0.0054
Epoch 9, Loss: 0.0053
Epoch [10/50], Validation Loss: 0.0057
Epoch 10, Loss: 0.0051
Epoch [11/50], Validation Loss: 0.0044
Epoch 11, Loss: 0.0044
Epoch [12/50], Validation Loss: 0.0042
Epoch 12, Loss: 0.0042
Epoch [13/50], Validation Loss: 0.0040
Epoch 13, Loss: 0.0040
Epoch [14/50], Validation Loss: 0.0039
Epoch 14, Loss: 0.0039
Epoch [15/50], Validation Loss: 0.0038
Epoch 15, Loss: 0.0038
Epoch [16/50], Validation Loss: 0.0038
Epoch 16

# Model 저장 및 불러오기

In [None]:
# Save the trained weights to Drive so they outlive the Colab runtime.
checkpoint_path = "/content/drive/MyDrive/srcnn_checkpoint.pth"
torch.save(model.state_dict(), checkpoint_path)

# Reload into a fresh model from the SAME path. The bare relative name
# "srcnn_checkpoint.pth" would be resolved against the working directory, where
# this cell never saved anything.
# map_location lets a GPU-saved checkpoint load on a CPU-only runtime, and
# weights_only=True refuses to unpickle arbitrary objects from the file.
new_model = srcnn().to(device)
new_model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))
new_model.eval()

srcnn(
  (layer1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
    (1): ReLU(inplace=True)
  )
  (layer2): Sequential(
    (0): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
    (1): ReLU(inplace=True)
  )
  (layer3): Sequential(
    (0): Conv2d(32, 3, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  )
)

In [None]:
# Write the super-resolved test patches to Drive as SR_output_1.png, SR_output_2.png, ...
save_dir = "/content/drive/MyDrive/SR_outputs"
os.makedirs(save_dir, exist_ok=True)

# Set eval mode explicitly. If the training cell is interrupted mid-epoch
# (e.g. KeyboardInterrupt), `model` is left in train mode.
# NOTE(review): this uses the in-memory `model`, not `new_model` reloaded from the
# checkpoint; switch to `new_model` to check that the saved weights round-trip.
model.eval()
n_saved = 0
with torch.no_grad():
  for lr_batch, _ in test_loader:
    # Clamp to [0, 1] because the network's last layer is unbounded.
    sr_batch = model(lr_batch.to(device)).clamp(0, 1).cpu()
    for sr_img in sr_batch:
      n_saved += 1
      save_path = os.path.join(save_dir, f"SR_output_{n_saved}.png")
      F.to_pil_image(sr_img).save(save_path)