Воспользуйтеесь инструкцией с https://www.analyticsvidhya.com/blog/2021/06/how-to-load-kaggle-datasets-directly-into-google-colab/ для скачки датасета

In [1]:
import os

from tqdm import tqdm

import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset


import math

import numpy as np

import torchvision
from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize, Grayscale
from torchmetrics import StructuralSimilarityIndexMeasure, PeakSignalNoiseRatio

from PIL import Image

import matplotlib.pyplot as plt

%matplotlib inline

In [2]:
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x1d94e096b20>

In [3]:
UPSCALE_FACTOR = 4
CROP_SIZE = 512

In [4]:
from dataset import TrainDatasetFromFolder

In [5]:
train_set = TrainDatasetFromFolder("chest_xray/train", crop_size=CROP_SIZE,
                                   upscale_factor=UPSCALE_FACTOR)
trainloader = DataLoader(train_set, batch_size=2, num_workers=4, shuffle=True)



In [6]:
train_set.__getitem__(0)[0].shape, train_set.__getitem__(0)[1].shape

(torch.Size([1, 128, 128]), torch.Size([1, 512, 512]))

In [7]:
class D_Block(nn.Module):
    def __init__(self, in_channels, out_channels, stride=2):
        super().__init__()

        self.layer = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(),
        )

    def forward(self, x):

        return self.layer(x)


class Discriminator(nn.Module):
    def __init__(self, img_size, in_channels=1):
        super().__init__()

        self.conv_1_1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, (3, 3), stride=1, padding=1), nn.LeakyReLU()
        )

        self.block_1_1 = D_Block(64, 64, stride=2)  # stride= 2 if output 4x
        self.block_1_2 = D_Block(64, 128, stride=1)
        self.block_1_3 = D_Block(128, 128)

        self.conv_2_1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, (3, 3), stride=1, padding=1), nn.LeakyReLU()
        )

        self.block_2_2 = D_Block(64, 128, stride=1)

        self.block3 = D_Block(256, 256, stride=1)
        self.block4 = D_Block(256, 256)
        self.block5 = D_Block(256, 512, stride=1)
        self.block6 = D_Block(512, 512)
        self.block7 = D_Block(512, 1024)
        self.block8 = D_Block(1024, 1024)

        self.flatten = nn.Flatten()

        self.fc1 = nn.Linear(1024 * img_size[0] * img_size[1] // 256, 100) # Change based on input image size
        self.fc2 = nn.Linear(100, 2)

        self.relu = nn.LeakyReLU(negative_slope=0.2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x1, x2):

        x_1 = self.block_1_3(self.block_1_2(self.block_1_1(self.conv_1_1(x1))))
        x_2 = self.block_2_2(self.conv_2_1(x2))

        x = torch.cat([x_1, x_2], dim=1)
        x = self.block8(
            self.block7(self.block6(self.block5(self.block4(self.block3(x)))))
        )

        x = self.flatten(x)


        x = self.fc1(x)
        x = self.fc2(self.relu(x))

        return self.sigmoid(x)

In [8]:
class RWMAB(nn.Module):
    def __init__(self, in_channels):
        super().__init__()

        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, (3, 3), stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels, in_channels, (3, 3), stride=1, padding=1),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, (1, 1), stride=1, padding=0),
            nn.Sigmoid(),
        )

    def forward(self, x):

        x_ = self.layer1(x)
        x__ = self.layer2(x_)

        x = x__ * x_ + x

        return x


class ShortResidualBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()

        self.layers = nn.ModuleList([RWMAB(in_channels) for _ in range(16)])

    def forward(self, x):

        x_ = x.clone()

        for layer in self.layers:
            x_ = layer(x_)

        return x_ + x


class Generator(nn.Module):
    def __init__(self, in_channels=1, blocks=8):
        super().__init__()

        self.conv = nn.Conv2d(in_channels, 64, (3, 3), stride=1, padding=1)

        self.short_blocks = nn.ModuleList(
            [ShortResidualBlock(64) for _ in range(blocks)]
        )

        self.conv2 = nn.Conv2d(64, 64, (1, 1), stride=1, padding=0)

        self.conv3 = nn.Sequential(
            nn.Conv2d(128, 256, (3, 3), stride=1, padding=1),
            nn.PixelShuffle(2),
            nn.Conv2d(64, 256, (3, 3), stride=1, padding=1),
            nn.PixelShuffle(2),  # Remove if output is 2x the input
            nn.Conv2d(64, 1, (1, 1), stride=1, padding=0),  # Change 64 -> 256
            nn.Sigmoid(),
        )

    def forward(self, x):

        x = self.conv(x)
        x_ = x.clone()

        for layer in self.short_blocks:
            x_ = layer(x_)

        x = torch.cat([self.conv2(x_), x], dim=1)

        x = self.conv3(x)

        return x

In [9]:
device  = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Standard device selectoin
device

device(type='cuda')

In [10]:
netG = Generator(in_channels=1, blocks=2)
netD = Discriminator(
    [int(CROP_SIZE/UPSCALE_FACTOR), 
     int(CROP_SIZE/UPSCALE_FACTOR)],
     in_channels=1
)


gen = netG.to(device)
disc = netD.to(device)

In [11]:
feature_extractor = torchvision.models.mobilenetv3.mobilenet_v3_small(in_channels=1)
feature_extractor.features[0][0] = nn.Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)

In [12]:
feature_extractor = feature_extractor.to(device)

In [13]:
optimizer_G = optim.Adam(gen.parameters(), lr=1e-3, weight_decay=1e-5)
optimizer_D = optim.Adam(disc.parameters(), lr=1e-4, weight_decay=1e-5)
loss_function = torch.nn.L1Loss().to(device)
gan_loss = torch.nn.BCEWithLogitsLoss().to(device)
scaler = torch.cuda.amp.GradScaler()

In [14]:
results = {
    "d_loss":[],
    "g_loss":[],
    "d_score": [],
    "g_score": []
}

In [15]:
N_EPOCHS = 150 # 150 is good enough for our model. gives decent enough results

In [None]:
for epoch in range(1, N_EPOCHS + 1):
  train_bar = tqdm(trainloader)
  running_results = {'batch_sizes':0, 'd_loss':0,
                     "g_loss":0, "d_score":0, "g_score":0}

  metrics = [PeakSignalNoiseRatio(), StructuralSimilarityIndexMeasure()]

  netG.train()
  netD.train()
  for data in train_bar:
    batch_size = data[0].size(0)
    running_results['batch_sizes'] += batch_size

    lr_img, hr_img = data

    lr_img, hr_img = lr_img.to(device), hr_img.to(device)

    valid = Variable(torch.Tensor(np.ones((lr_img.shape[0], 2))), requires_grad=False).to(device)
    fake = Variable(torch.Tensor(np.zeros((lr_img.shape[0], 2))), requires_grad=False).to(device)

    pred_hr = gen(lr_img)

    content_loss = loss_function(pred_hr, hr_img)

    pred_features = feature_extractor(pred_hr)
    hr_features = feature_extractor(hr_img)

    feature_loss = 0.0

    for pred_feature, hr_feature in zip(pred_features, hr_features):
        feature_loss += loss_function(pred_feature, hr_feature)

    pred_real = disc(hr_img.detach(), lr_img)
    pred_fake = disc(pred_hr, lr_img)

    gan_loss_num = gan_loss(
        pred_fake - pred_real.mean(0, keepdim=True), valid
    )

    loss_G = content_loss * 0.1 + feature_loss * 0.1 + gan_loss_num

    optimizer_G.zero_grad()
    scaler.scale(loss_G).backward()
    scaler.step(optimizer_G)
    scaler.update()

    # Train Discriminator

    pred_real = disc(hr_img, lr_img)
    pred_fake = disc(pred_hr.detach(), lr_img)

    loss_real = gan_loss(pred_real - pred_fake.mean(0, keepdim=True), valid)
    loss_fake = gan_loss(pred_fake - pred_real.mean(0, keepdim=True), fake)

    loss_real_num = gan_loss(pred_real, valid)
    loss_fake_num = gan_loss(pred_fake, fake)

    loss_D = ((loss_real + loss_fake) / 2) + (
        (loss_real_num + loss_fake_num) / 2
    )

    optimizer_D.zero_grad()

    running_results['g_loss'] += loss_G.item() * batch_size
    running_results['d_loss'] += loss_D.item() * batch_size
    running_results['d_score'] += loss_real_num.item() * batch_size
    running_results['g_score'] += gan_loss_num.item() * batch_size
    metrics[0](hr_img.detach().cpu()*255, pred_hr.detach().cpu()*255).item()
    metrics[1](hr_img.detach().cpu()*255, pred_hr.detach().cpu()*255).item()

    ## Updating the progress bar
    train_bar.set_description(desc="[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f PSNR: %2f SSIM: %2f" % (
        epoch, N_EPOCHS, running_results['d_loss'] / running_results['batch_sizes'],
        running_results['g_loss'] / running_results['batch_sizes'],
        running_results['d_score'] / running_results['batch_sizes'],
        running_results['g_score'] / running_results['batch_sizes'],
        metrics[0].compute().item(),
        metrics[1].compute().item()
    ))
  netG.eval()

[1/150] Loss_D: 1.4120 Loss_G: 0.7498 D(x): 0.4640 D(G(z)): 0.6988 PSNR: 5.405110 SSIM: 0.003313:   9%| | 224/2608 [59:

In [None]:
def plot_images(images):
  grid_img = torchvision.utils.make_grid(images.detach().cpu()[:16])
  grid_img = (grid_img.permute(1, 2, 0).numpy()*255)

  plt.figure(figsize=(20, 20))
  plt.imshow(grid_img.astype(np.uint8))

In [None]:
plot_images(fake_img)
plot_images(real_img)