Definicja modelu, wczytanie i przetwarzanie zbioru uczącego.


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
import os
from PIL import Image
import numpy as np


class CustomImageDataset(Dataset):
    """Dataset of image files from one directory, yielding (input, target) pairs.

    Both elements of a pair come from the same file: the picture is loaded and
    converted to RGB, then ``transform`` is applied to produce the model input
    and ``target_transform`` is applied to a copy to produce the target.

    Args:
        img_dir: Directory containing the image files.
        transform: Optional callable applied to the input image.
        target_transform: Optional callable applied to the target image.
    """

    def __init__(self, img_dir, transform=None, target_transform=None):
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform
        # Keep only regular, non-hidden files: Jupyter creates an
        # ".ipynb_checkpoints" directory and some OSes add files such as
        # ".DS_Store", none of which can be opened as an image. Sorting makes
        # the index -> file mapping independent of the filesystem's listing order.
        self.images = sorted(
            name
            for name in os.listdir(img_dir)
            if not name.startswith(".")
            and os.path.isfile(os.path.join(img_dir, name))
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.images[idx])
        # convert() returns a new, fully decoded image, so the file handle can
        # be closed right away instead of waiting for garbage collection.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")
        target = image.copy()
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            target = self.target_transform(target)
        return image, target


# Model input: the picture reduced to a single grayscale channel, then
# turned into a float tensor.
transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
    ]
)

# Reconstruction target: the full-color picture as a float tensor.
target_transform = transforms.Compose([transforms.ToTensor()])

# Training images, served one at a time in random order every epoch.
dataset = CustomImageDataset(
    "learning_img_basic",
    transform=transform,
    target_transform=target_transform,
)
dataloader = DataLoader(dataset, shuffle=True, batch_size=1)

class Autoencoder(nn.Module):
    """Convolutional autoencoder mapping a 1-channel image to a 3-channel image.

    The encoder halves the spatial size twice (stride-2 convolutions) and then
    applies an unpadded 7x7 convolution. The decoder mirrors it with transposed
    convolutions and ends in a sigmoid, so every output value lies in [0, 1].

    Note: with these kernel/stride/padding settings, an output side length
    equals the input side length only when that length is congruent to 2
    modulo 4 (e.g. 30 -> 15 -> 8 -> 2 -> 8 -> 15 -> 30). Other sizes come back
    1-2 pixels off.

    The layer order (and therefore the ``encoder.N`` / ``decoder.N``
    state_dict keys) must stay fixed so saved checkpoints keep loading.
    """

    def __init__(self):
        super().__init__()
        encoder_layers = [
            nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=7),
            nn.BatchNorm2d(64),
        ]
        decoder_layers = [
            nn.ConvTranspose2d(64, 32, kernel_size=7),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=0),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 3, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),  # squash to [0, 1], the same range as ToTensor targets
        ]
        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode a batch of 1-channel images and decode it into 3-channel images."""
        latent = self.encoder(x)
        return self.decoder(latent)


# Network to train, pixel-wise reconstruction loss, and the optimizer that
# updates every parameter of the network.
model = Autoencoder()
criterion = nn.MSELoss()  # mean squared error between the output and the RGB target
optimizer = optim.Adam(params=model.parameters(), lr=5e-4)



Trenowanie modelu


In [None]:

# Train the autoencoder to reproduce the RGB target from the grayscale input.
num_epochs = 400
losses = []  # mean training loss of each epoch (plotted below)

# Re-enable training behaviour (BatchNorm batch statistics) in case this cell
# is re-run after a cell that switched the model to eval() mode.
model.train()
for epoch in range(num_epochs):
    epoch_loss = 0.0
    num_batches = 0
    for img, target in dataloader:
        output = model(img)
        loss = criterion(output, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1
    if num_batches == 0:
        raise RuntimeError("Training dataloader is empty - no images to train on.")
    # Average over the whole epoch instead of keeping only the last batch's
    # loss, which with batch_size=1 is just the loss of one random image.
    mean_loss = epoch_loss / num_batches
    losses.append(mean_loss)
    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {mean_loss:.4f}')


torch.save(model.state_dict(), 'model.pth')

# Round-trip check: reload the checkpoint that was just written.
model.load_state_dict(torch.load('model.pth'))


# Epochs are numbered from 1 to match the progress messages above.
plt.plot(range(1, num_epochs + 1), losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss over Epochs')
plt.show()


Wczytywanie zapisanego modelu


In [None]:
# Rebuild the architecture and load a previously saved state_dict into it.
# NOTE(review): this loads 'batch_normalization.pth', while the training cell
# saves to 'model.pth' - confirm which checkpoint is intended.
loadmodel = Autoencoder()
# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only machine.
loadmodel.load_state_dict(torch.load('batch_normalization.pth', map_location='cpu'))
# The model is only used for inference: make BatchNorm use its stored running
# statistics instead of the statistics of the (single-image) input batch.
loadmodel.eval()

Testowanie modelu i generowanie obrazów

In [None]:


from PIL import Image


def _show_image(image_tensor, title):
    """Display an image tensor shaped (C, H, W) or (1, C, H, W) with matplotlib."""
    if image_tensor.dim() == 4:
        # Drop only the batch dimension; squeeze() would also drop H or W when they equal 1.
        image_tensor = image_tensor[0]
    plt.imshow(np.transpose(image_tensor.detach().cpu().numpy(), (1, 2, 0)))
    plt.title(title)
    plt.axis('off')
    plt.show()


testdataset = CustomImageDataset("testing_img_basic", transform=transform, target_transform=target_transform)
# Evaluation order does not change the average loss, so keep it deterministic.
testdataloader = DataLoader(testdataset, batch_size=1, shuffle=False)

# NOTE(review): the test set is evaluated with `model` (the network trained
# above), while the single image below uses `loadmodel` (the checkpoint loaded
# from disk) - confirm this split is intended.
# eval() makes BatchNorm use its running statistics and stops test images from
# updating them; torch.no_grad() alone does not prevent that update.
model.eval()
sumloss = 0
with torch.no_grad():
    for img, target in testdataloader:
        output = model(img)
        loss = criterion(output, target)
        _show_image(target, 'Original Image')
        sumloss += loss.item()
        print(f'Loss: {loss.item():.4f}')
        _show_image(output, 'Colorized Image')
if len(testdataloader) > 0:
    # criterion returns a per-batch mean, so average over the number of batches.
    avgloss = sumloss / len(testdataloader)
    print(f'Average Loss: {avgloss:.4f}')
else:
    print('No test images found.')

# Colorize a single image of a different size with the loaded checkpoint.
with Image.open("different_size_img.jpg") as raw_image:
    testimage = transform(raw_image.convert("L"))
loadmodel.eval()
with torch.no_grad():
    testresult = loadmodel(testimage.unsqueeze(0))

_show_image(testresult, 'Colorized Image')