In [14]:
import os
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
import numpy as np


In [15]:
class CIFAR10Dataset(Dataset):
    """CIFAR-10 training images served as (input, target) pairs for pixel prediction.

    Reads the five pickled files ``data_batch_1`` .. ``data_batch_5`` found in
    ``data_dir`` and stacks their ``b'data'`` arrays into a single
    ``(N, 3, 32, 32)`` tensor. Only the pixel data is read; labels are ignored.

    The pixels are kept in the dtype stored on disk (uint8 for the standard
    CIFAR-10 batches) and are only widened to ``torch.long`` when an item is
    fetched. Holding the whole set as int64 would take 8x the memory for no
    benefit, since the values still fit in a byte.

    Args:
        data_dir: Directory containing the ``data_batch_*`` files.
    """

    def __init__(self, data_dir):
        batches = []
        for batch_id in range(1, 6):
            # NOTE(review): pickle.load can execute arbitrary code while
            # unpickling; only point this at batch files from a trusted source.
            with open(os.path.join(data_dir, f'data_batch_{batch_id}'), 'rb') as f:
                entry = pickle.load(f, encoding='bytes')
            batches.append(entry[b'data'])
        pixels = np.vstack(batches).reshape(-1, 3, 32, 32)
        # Share the freshly stacked buffer instead of copying it into int64.
        self.data = torch.from_numpy(pixels)

    def __len__(self):
        """Return the number of images."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(x, x)``: the image at ``idx`` as a long tensor, twice.

        The same tensor is used as model input and as the per-pixel class
        target, hence the integer (long) dtype.
        """
        x = self.data[idx].long()
        return x, x  # Input = Target for pixel prediction

In [16]:
# Where the CIFAR-10 batch files live, and how many images go in each mini-batch.
dataset_path = './datasets/cifar-10-batches-py'
batch_size = 16

# Load every training image up front and serve them in shuffled mini-batches.
dataset = CIFAR10Dataset(dataset_path)
loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
)

# RowLSTM block
class RowLSTM(nn.Module):
    """Convolutional LSTM that scans a feature map one row at a time, top to bottom.

    For each row, the gate pre-activations are the sum of a ``(1, kernel_size)``
    convolution over that row of the input and a ``1x1`` convolution over the
    hidden state carried over from the previous row.

    Args:
        input_size: Number of channels of the input feature map.
        hidden_size: Number of channels of the hidden and cell states.
        kernel_size: Width of the input-to-state kernel. The padding is
            ``kernel_size // 2``, which preserves the row width for odd sizes.
    """

    def __init__(self, input_size, hidden_size, kernel_size=3):
        super().__init__()
        padding = kernel_size // 2
        self.input_conv = nn.Conv2d(input_size, 4 * hidden_size, (1, kernel_size), padding=(0, padding))
        self.hidden_conv = nn.Conv2d(hidden_size, 4 * hidden_size, (1, 1))
        self.hidden_dim = hidden_size

    def forward(self, x, h, c):
        """Run the recurrence over every row of ``x``.

        Args:
            x: Input feature map of shape ``(B, input_size, H, W)``.
            h: Initial hidden state; must be addable to one row of gates,
                i.e. ``(B, hidden_size, 1, W)``.
            c: Initial cell state, same shape as ``h``.

        Returns:
            A tuple ``(outputs, h, c)`` where ``outputs`` is the hidden state of
            every row concatenated along the height axis,
            ``(B, hidden_size, H, W)``, and ``h`` / ``c`` are the states after
            the last row.
        """
        # The input-to-state kernel has height 1 and no vertical padding, so
        # each output row depends only on the matching input row. Computing all
        # rows in one call replaces H small convolutions with a single one.
        input_gates = self.input_conv(x)

        outputs = []
        for row_gates in input_gates.split(1, dim=2):
            gates = row_gates + self.hidden_conv(h)
            # Channel layout of the gates: output, forget, input, candidate.
            o, f, i_gate, g = gates.chunk(4, dim=1)
            o = torch.sigmoid(o)
            f = torch.sigmoid(f)
            i_gate = torch.sigmoid(i_gate)
            g = torch.tanh(g)
            c = f * c + i_gate * g
            h = o * torch.tanh(c)
            outputs.append(h)
        return torch.cat(outputs, dim=2), h, c

In [17]:
class PixelRNN(nn.Module):
    """Stack of RowLSTM layers that outputs per-pixel, per-channel class logits.

    Pipeline: 7x7 input convolution -> ``num_layers`` RowLSTM layers -> 1x1
    output convolution producing ``input_dim * output_classes`` channels, which
    are reshaped so each colour channel gets its own ``output_classes`` logits.

    NOTE(review): ``input_conv`` is a plain (unmasked) 7x7 convolution centred
    on each pixel, so the logits at a position are computed from features that
    include that position's own input value. If the model is meant to be
    autoregressive, the convolutions need causal masks - confirm intent.

    Args:
        input_dim: Number of image channels.
        hidden_dim: Number of feature channels used throughout the network.
        output_classes: Number of discrete intensity levels predicted per channel.
        num_layers: Number of stacked RowLSTM layers.
    """

    def __init__(self, input_dim=3, hidden_dim=64, output_classes=256, num_layers=4):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.output_classes = output_classes
        self.input_conv = nn.Conv2d(input_dim, hidden_dim, kernel_size=7, padding=3)
        self.lstm_layers = nn.ModuleList([RowLSTM(hidden_dim, hidden_dim) for _ in range(num_layers)])
        self.output_conv = nn.Conv2d(hidden_dim, input_dim * output_classes, kernel_size=1)

    def forward(self, x):
        """Compute class logits for every pixel and channel.

        Args:
            x: Image batch of shape ``(B, C, H, W)`` holding intensities on a
                0-255 scale (it is cast to float and divided by 255).

        Returns:
            Logits of shape ``(B, C, output_classes, H, W)``.
        """
        B, C, H, W = x.size()

        x = self.input_conv(x.float() / 255.0)
        for lstm in self.lstm_layers:
            # Every layer scans the image from its own zero state. Reusing the
            # state left over from the previous layer's LAST row as the start
            # state of this layer's FIRST row would let the bottom of the image
            # influence the top.
            h = torch.zeros(B, self.hidden_dim, 1, W, device=x.device)
            c = torch.zeros_like(h)
            x, _, _ = lstm(x, h, c)
        x = self.output_conv(x)
        # Channel k of the 1x1 conv output maps to (colour k // classes, class k % classes).
        x = x.view(B, C, self.output_classes, H, W)
        return x

In [18]:
def train(model, loader, device, epochs=5):
    """Train ``model`` to predict each pixel's intensity class with Adam.

    For every batch the images are used both as the model input and as the
    classification target. The loss is the sum, over the channel axis of the
    model output, of the cross-entropy between that channel's logits and that
    channel's pixel values. After each epoch the mean batch loss is printed.

    Args:
        model: Module mapping a ``(B, C, H, W)`` batch to logits of shape
            ``(B, C, num_classes, H, W)``.
        loader: Iterable yielding ``(images, _)`` pairs; the second element is
            ignored. ``images`` must hold integer (long) class indices. It is
            iterated once per epoch and does not need to define ``__len__``.
        device: Device the batches are moved to before the forward pass.
        epochs: Number of passes over ``loader``.

    Raises:
        ValueError: If ``loader`` yields no batches during an epoch.
    """
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        for images, _ in loader:
            images = images.to(device)

            outputs = model(images)

            # One cross-entropy term per colour channel; the input is its own target.
            loss = sum(
                criterion(outputs[:, ch], images[:, ch])
                for ch in range(outputs.size(1))
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1

        # Average over the batches actually seen rather than len(loader), which
        # is undefined for plain iterables and zero for an empty loader.
        if num_batches == 0:
            raise ValueError("loader yielded no batches; cannot compute an epoch loss")
        print(f"Epoch {epoch + 1}, Loss: {total_loss / num_batches:.4f}")

In [None]:
# Use the GPU when one is visible; otherwise fall back to the CPU so the cell
# still runs (slowly) on machines without CUDA instead of raising.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = PixelRNN().to(device)

print("Starting training...")
train(model, loader, device, epochs=5)

Starting training...


CUDA available: True
Current device: 0
Device name: NVIDIA GeForce RTX 2080 Super
