In [1]:
import torch
import torch.nn as nn
torch.cuda.is_available()

True

In [2]:
# download datasets if not exist
from torchvision.datasets.oxford_iiit_pet import OxfordIIITPet

# root
root = './data'
# download
_ = OxfordIIITPet(root, download=True)

In [3]:
# nn.Conv2d를 이용한 패치 임베딩
import torch
import torch.nn as nn
# Reference
# https://github.com/jankrepl/mildlyoverfitted/blob/master/github_adventures/vision_transformer/custom.py

class PatchEmbed(nn.Module):
    """Split image into patches and then embed them.

    Parameters
    ----------
    img_size : int
        Size of the image (it is a square).

    patch_size : int
        Size of the patch (it is a square).

    in_chans : int
        Number of input channels.

    embed_dim : int
        The emmbedding dimension.

    Attributes
    ----------
    n_patches : int
        Number of patches inside of our image.

    proj : nn.Conv2d
        Convolutional layer that does both the splitting into patches
        and their embedding.
    """
    def __init__(self, img_size, patch_size, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        assert img_size % patch_size == 0,\
            f"img_size({img_size} is not divisable by patch_size({patch_size}))"

        self.n_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size
        )

    def forward(self, x):
        x = self.proj(x)
        x = x.flatten(2)
        x = x.transpose(1, 2)
        return x


# test
img_size = 224
patch_size = 16
in_chans = 3
embed_dim = 768
x = torch.randn(1, in_chans, img_size, img_size)
model = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
out = model(x)
print(out.shape)

torch.Size([1, 196, 768])


In [4]:
# Transformer based model
class TransformerAutoEncoder(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, num_heads=12, num_layers=12):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        self.pos_embed = nn.Parameter(torch.zeros(1, (img_size // patch_size) ** 2, embed_dim))
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=embed_dim,
                nhead=num_heads
            ),
            num_layers=num_layers
        )
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=embed_dim,
                nhead=num_heads
            ),
            num_layers=num_layers
        )
        self.proj = nn.Linear(embed_dim, in_chans * patch_size * patch_size)

    def forward(self, x):
        x = self.patch_embed(x)
        x = x + self.pos_embed
        memory = self.encoder(x)
        x = self.decoder(x, memory)
        x = self.proj(x)
        x = x.reshape(x.shape[0], in_chans, img_size, img_size)
        return x

# test
model = TransformerAutoEncoder()
out = model(torch.randn(1, 3, 224, 224))
print(out.shape)

torch.Size([1, 3, 224, 224])




In [5]:
# data augmentation
from torchvision.transforms import v2

# image size
IMG_SIZE = 224

train_transforms = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.uint8),
    v2.CenterCrop(size=(IMG_SIZE, IMG_SIZE)),
    v2.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

test_transforms = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.uint8),
    v2.CenterCrop(size=(IMG_SIZE, IMG_SIZE)),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# test transforms
random_tensor = torch.randint(0, 255, (3, 224, 224), dtype=torch.uint8)
train_transform_result = train_transforms(random_tensor)
test_transform_result = test_transforms(random_tensor)
print(train_transform_result.shape)
print(test_transform_result.shape)


torch.Size([3, 224, 224])
torch.Size([3, 224, 224])


In [6]:
# dataloader
from torch.utils.data import DataLoader
from torchvision.datasets import OxfordIIITPet

# train_datasets
train_datasets = OxfordIIITPet(root, transform=train_transforms)
test_datasets = OxfordIIITPet(root, split="test", transform=test_transforms)

# check datasets length
print(len(train_datasets))
print(len(test_datasets))

BATCH_SIZE = 64

# dataloaders
train_dataloader = DataLoader(train_datasets, batch_size=BATCH_SIZE, shuffle=True, num_workers=16)
test_dataloader = DataLoader(test_datasets, batch_size=BATCH_SIZE, shuffle=False, num_workers=16)

# test dataloaders
for x, y in train_dataloader:
    print(x.shape, y)
    break

for x, y in test_dataloader:
    print(x.shape, y)
    break


3680
3669


torch.Size([64, 3, 224, 224]) tensor([13, 22, 14,  5, 12, 10,  3, 15,  6, 35,  2, 22, 30,  5, 12, 24, 10, 28,
        19, 26, 35, 10, 31, 19, 23, 16, 23, 12,  1, 36,  6, 12, 36, 26, 14, 18,
        31,  3, 25, 27, 29, 33, 21, 19, 16, 26, 22, 29, 34, 28, 31, 14,  0, 28,
        22, 11, 13, 20,  9, 24,  0, 32, 21,  1])
torch.Size([64, 3, 224, 224]) tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])


In [7]:
# check device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(DEVICE)

cuda


In [8]:
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn import MSELoss

# 모델 선언
model = TransformerAutoEncoder().to(DEVICE)

# 옵티마이저, 스케줄러, 손실 함수 선언
optimizer = Adam(model.parameters(), lr=1e-3)
scheduler = CosineAnnealingLR(optimizer, T_max=20, eta_min=1e-6)
criterion = MSELoss()

In [9]:
import torch
import torchmetrics
from tqdm import tqdm
from torch.cuda.amp import GradScaler, autocast

def train_fn(model, dataloader, criterion, optimizer, scheduler, device, scaler):
    model.train()
    train_loss = 0.0

    for images, _ in tqdm(dataloader, desc="Training"):
        images = images.to(device)
        optimizer.zero_grad()

        with torch.cuda.amp.autocast():
            outputs = model(images)
            loss = criterion(outputs, images)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        scheduler.step()
        train_loss += loss.item()

    train_loss /= len(dataloader)
    print(f"Train Loss: {train_loss:.4f}")

def test_fn(model, dataloader, criterion, device):
    model.eval()
    test_loss = 0.0
    psnr = 0.0
    ssim = 0.0

    psnr_metric = torchmetrics.image.PeakSignalNoiseRatio().to(device)
    ssim_metric = torchmetrics.image.StructuralSimilarityIndexMeasure().to(device)

    with torch.no_grad():
        for images, _ in tqdm(dataloader, desc="Testing"):
            images = images.to(device)

            with torch.cuda.amp.autocast():
                outputs = model(images)
                loss = criterion(outputs, images)

            test_loss += loss.item()
            psnr += psnr_metric(outputs, images).item()
            ssim += ssim_metric(outputs, images).item()

    test_loss /= len(dataloader)
    psnr /= len(dataloader)
    ssim /= len(dataloader)

    print(f"Test Loss: {test_loss:.4f}, PSNR: {psnr:.2f}, SSIM: {ssim:.4f}")
    return test_loss

def train_autoencoder(model, train_dataloader, val_dataloader, criterion, optimizer, scheduler, device, epochs, save_period):
    best_loss = float('inf')
    scaler = GradScaler()

    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs}")
        train_fn(model, train_dataloader, criterion, optimizer, scheduler, device, scaler)
        
        if val_dataloader is not None:
            val_loss = test_fn(model, val_dataloader, criterion, device)
            if val_loss < best_loss:
                best_loss = val_loss
                torch.save(model.state_dict(), f"best_model_epoch_{epoch+1}.pth")
        
        if (epoch + 1) % save_period == 0:
            torch.save(model.state_dict(), f"model_epoch_{epoch+1}.pth")
    
    return model

In [10]:
# train
model = train_autoencoder(model, train_dataloader, test_dataloader, criterion, optimizer, scheduler, DEVICE, epochs=100, save_period=5)

Epoch 1/10


Training: 100%|██████████| 58/58 [00:20<00:00,  2.78it/s]


Train Loss: 1.3952


Testing: 100%|██████████| 58/58 [00:06<00:00,  9.15it/s]


Test Loss: 1.3701, PSNR: 12.21, SSIM: 0.0076
Epoch 2/10


Training:   7%|▋         | 4/58 [00:02<00:25,  2.14it/s]