<a href="https://colab.research.google.com/github/Amanux7/Pixel/blob/main/Pixel_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
!pip install transformers matplotlib diffusers accelerate datasets requests
print("Libraries installed!")

Looking in indexes: https://download.pytorch.org/whl/cu121
Libraries installed!


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import math
from torchvision import transforms
from transformers import CLIPTokenizer
from PIL import Image
import requests
from io import BytesIO
from datasets import load_dataset
from huggingface_hub import login
from google.colab import drive
print("Imports complete!")

Imports complete!


In [None]:
# Stream the training split and keep only the first 10,000 records.
dataset = load_dataset("mengcy/LAION-SG", split="train", streaming=True).take(10000)

# Preview raw sample
sample = next(iter(dataset))
print("Caption:", sample['caption_ori'])
img_url = sample['url']
try:
    # Use the same bounded timeout and HTTP status check as the preprocessing
    # step, so a dead link fails fast instead of hanging or surfacing later
    # as an image-decoding error.
    response = requests.get(img_url, timeout=10)
    response.raise_for_status()
    img = Image.open(BytesIO(response.content))
    img.show()  # Pops up image
except (requests.RequestException, OSError) as err:
    # The preview is only informational: an unreachable URL or undecodable
    # payload should not abort the rest of the notebook.
    print(f"Preview image unavailable ({img_url}): {err}")
print("Dataset loaded!")

README.md: 0.00B [00:00, ?B/s]

Caption: Inauguration Eve- Signed By The Artist								 – Canvas Giclee – Limited Edition – 395 S/N – 36 x 48
Dataset loaded!


In [None]:
# Image pipeline: force a 256x256 resolution, convert to a float tensor, then
# shift/scale every channel with mean 0.5 and std 0.5.
image_transform = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# Caption tokenizer taken from the pretrained CLIP ViT-B/32 checkpoint.
tokenizer = CLIPTokenizer.from_pretrained(
    "openai/clip-vit-base-patch32"
)
print("Transforms and tokenizer ready!")

tokenizer_config.json:   0%|          | 0.00/592 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/389 [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json: 0.00B [00:00, ?B/s]

Transforms and tokenizer ready!


In [None]:
def batched_preprocess(examples):
    """Download, decode and tokenize one batch of (url, caption) pairs.

    Args:
        examples: Mapping with parallel sequences under the keys ``'url'``
            and ``'caption_ori'``.

    Returns:
        A dict with two equally long lists: ``'image'`` holds the outputs of
        ``image_transform`` and ``'text'`` holds the caption token ids, padded
        or truncated to 77 entries. Samples whose download, decoding or
        tokenization fails are skipped, so both lists may be empty.
    """
    images = []
    texts = []
    for url, caption in zip(examples['url'], examples['caption_ori']):
        try:
            response = requests.get(url, timeout=10)
            response.raise_for_status()
            img = Image.open(BytesIO(response.content)).convert("RGB")
            img_tensor = image_transform(img)
            encoded = tokenizer(
                caption,
                padding="max_length",
                max_length=77,
                truncation=True,
                return_tensors="pt",
            )
            # Drop only the leading batch axis added by return_tensors="pt".
            text_tokens = encoded['input_ids'].squeeze(0)
        except Exception:
            # Best-effort: skip any sample that cannot be fetched or parsed.
            # Catching Exception (not a bare except) lets KeyboardInterrupt
            # and SystemExit propagate so a long download can be stopped.
            continue
        # Append as a pair only after every step succeeded, keeping the two
        # lists aligned.
        images.append(img_tensor)
        texts.append(text_tokens)
    return {'image': images, 'text': texts}  # Dict of lists, empty if no valid
print("Preprocess function defined!")

Preprocess function defined!


In [None]:
# Run batched_preprocess over the dataset ten records at a time, drop the
# original columns, then keep only rows that carry an image.
processed_dataset = (
    dataset
    .map(
        batched_preprocess,
        batched=True,
        batch_size=10,
        remove_columns=dataset.column_names,
    )
    .filter(lambda row: row['image'] is not None)
)

# Preview processed
sample = next(iter(processed_dataset))
print("Text Tokens:", sample['text'])
print("Image Shape:", sample['image'].shape)
print(
    "Min/Max:",
    sample['image'].min(),
    sample['image'].max(),
)
print("Preprocessing complete!")

Text Tokens: tensor([49406, 16805,  1866,   268,  3163,   638,   518,  2456,  1224,  7483,
        27900,  2592,  1224,  3472,  3062,  1224,   274,   280,   276,   338,
          270,   333,  1224,   274,   277,   343,   275,   279, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407])
Image Shape: torch.Size([3, 256, 256])
Min/Max: tensor(-0.7882) tensor(0.9843)
Preprocessing complete!


In [None]:
# Diffusion schedule: T steps with betas spaced linearly from 1e-4 to 0.02.
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alphas_cumprod = alphas.cumprod(dim=0)
# Cumulative product shifted right by one step, with 1.0 in the first slot.
alphas_cumprod_prev = torch.cat([alphas_cumprod.new_ones(1), alphas_cumprod[:-1]])
sqrt_recip_alphas = (1.0 / alphas).sqrt()
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod).sqrt()
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)


def q_sample(x_start, t, noise=None):
    """Mix ``x_start`` with noise according to the schedule at timestep(s) ``t``.

    Args:
        x_start: Tensor to be noised; the per-step coefficients are reshaped
            to ``(-1, 1, 1, 1)`` and broadcast against it.
        t: Integer index (or indices) into the precomputed schedule tensors.
        noise: Optional noise tensor; when omitted, standard normal noise
            shaped like ``x_start`` is drawn.

    Returns:
        ``sqrt_alphas_cumprod[t] * x_start
        + sqrt_one_minus_alphas_cumprod[t] * noise``.
    """
    eps = torch.randn_like(x_start) if noise is None else noise
    signal_scale = sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
    noise_scale = sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)
    return signal_scale * x_start + noise_scale * eps


print("Schedules defined!")

Schedules defined!


In [None]:
class SimpleUNet(nn.Module):
    """Two-level convolutional U-Net conditioned on a diffusion timestep.

    The encoder widens channels 64 -> 128 -> 256 while halving the resolution
    twice with max pooling; the decoder upsamples bilinearly and concatenates
    the matching encoder activations. A sinusoidal timestep embedding is
    passed through an MLP and added to every stage through 1x1 convolutions.

    Args:
        in_channels: Channels of the input image tensor.
        out_channels: Channels of the returned tensor.
        time_dim: Width of the timestep embedding (sliced in steps of two by
            ``pos_encoding``, so an even value is expected).
    """

    def __init__(self, in_channels=3, out_channels=3, time_dim=256):
        super().__init__()
        self.time_dim = time_dim
        self.time_mlp = nn.Sequential(
            nn.Linear(self.time_dim, self.time_dim),
            nn.SiLU(),
            nn.Linear(self.time_dim, self.time_dim)
        )
        # 1x1 convolutions that map the time embedding to each stage's width.
        self.time_proj_64 = nn.Conv2d(time_dim, 64, 1)
        self.time_proj_128 = nn.Conv2d(time_dim, 128, 1)
        self.time_proj_256 = nn.Conv2d(time_dim, 256, 1)
        self.down1 = nn.Conv2d(in_channels, 64, 3, padding=1)
        self.down2 = nn.Conv2d(64, 128, 3, padding=1)
        self.bottleneck = nn.Conv2d(128, 256, 3, padding=1)
        # Decoder inputs are wider because skip activations are concatenated.
        self.up1 = nn.Conv2d(256 + 128, 128, 3, padding=1)
        self.up2 = nn.Conv2d(128 + 64, 64, 3, padding=1)
        self.out = nn.Conv2d(64, out_channels, 3, padding=1)

    def pos_encoding(self, t, channels):
        """Return the sinusoidal encoding of ``t`` with ``channels`` features.

        Args:
            t: Float tensor of shape ``(batch, 1)`` holding timesteps.
            channels: Size of the encoding; the first half holds sines and the
                second half cosines.

        Returns:
            Tensor of shape ``(batch, channels)``.
        """
        inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2, device=t.device).float() / channels))
        pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
        pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
        pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
        return pos_enc

    def forward(self, x, t):
        """Run the network on images ``x`` at timesteps ``t``.

        Args:
            x: Tensor of shape ``(batch, in_channels, H, W)``. H and W no
                longer need to be multiples of four, but they must survive two
                2x2 max-pooling steps.
            t: Tensor of shape ``(batch,)`` with one timestep per image.

        Returns:
            Tensor of shape ``(batch, out_channels, H, W)``.
        """
        t = t.unsqueeze(-1).float()
        t_emb = self.pos_encoding(t, self.time_dim)
        t_emb = self.time_mlp(t_emb)
        # (batch, time_dim) -> (batch, time_dim, 1, 1) so the 1x1 convs apply
        # and their output broadcasts over the spatial dimensions.
        t_emb = t_emb[(..., ) + (None, ) * 2]
        d1 = F.relu(self.down1(x) + self.time_proj_64(t_emb))
        d1_pool = F.max_pool2d(d1, 2)
        d2 = F.relu(self.down2(d1_pool) + self.time_proj_128(t_emb))
        d2_pool = F.max_pool2d(d2, 2)
        b = F.relu(self.bottleneck(d2_pool) + self.time_proj_256(t_emb))
        # Upsample to the exact size of the skip tensor rather than by a fixed
        # factor of two: max pooling floors odd sizes, so a fixed factor would
        # make the concatenation fail when H or W is not divisible by four.
        # For divisible sizes the result is unchanged.
        b_up = F.interpolate(b, size=d2.shape[-2:], mode='bilinear', align_corners=False)
        u1 = F.relu(self.up1(torch.cat([b_up, d2], dim=1)) + self.time_proj_128(t_emb))
        u1_up = F.interpolate(u1, size=d1.shape[-2:], mode='bilinear', align_corners=False)
        u2 = F.relu(self.up2(torch.cat([u1_up, d1], dim=1)) + self.time_proj_64(t_emb))
        return self.out(u2)

model = SimpleUNet()
print("Model defined!")

Model defined!


In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Move schedules to device
betas = betas.to(device)
alphas = alphas.to(device)
alphas_cumprod = alphas_cumprod.to(device)
# Move the shifted schedule itself; this line previously assigned
# `alphas_cumprod`, which silently replaced the shifted values.
alphas_cumprod_prev = alphas_cumprod_prev.to(device)
sqrt_recip_alphas = sqrt_recip_alphas.to(device)
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device)
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device)
posterior_variance = posterior_variance.to(device)

# Test forward
dummy_x = torch.randn(1, 3, 256, 256).to(device)
dummy_t = torch.tensor([0]).to(device)
# Only the output shape is inspected, so skip building an autograd graph.
with torch.no_grad():
    output = model(dummy_x, dummy_t)
print("Output Shape:", output.shape)

Output Shape: torch.Size([1, 3, 256, 256])


In [None]:
def collate_fn(batch):
    """Stack the image tensors of a batch, ignoring samples without one.

    Args:
        batch: Sequence of mappings; entries lacking an ``'image'`` key or
            holding ``None`` under it are dropped.

    Returns:
        A tensor with the remaining images stacked along a new first
        dimension, or ``None`` when no entry has a usable image.
    """
    usable = []
    for entry in batch:
        if 'image' not in entry:
            continue
        picture = entry['image']
        if picture is None:
            continue
        usable.append(picture)
    # Unconditional for now: captions are not returned.
    return torch.stack(usable) if usable else None

# Batch the processed samples four at a time, loading in the main process.
dataloader = DataLoader(
    processed_dataset,
    batch_size=4,
    num_workers=0,
    collate_fn=collate_fn,
)
print("DataLoader ready!")

DataLoader ready!


In [None]:
# Number of full passes over the dataloader.
num_epochs = 2
model.train()

for epoch in range(num_epochs):
    # Running sum of batch losses and the count of batches that contributed.
    total_loss, batch_count = 0, 0
    for batch in dataloader:
        # The collate function returns None when no sample had a usable image.
        if batch is not None:
            images = batch.to(device)
            optimizer.zero_grad()
            # Draw one random timestep per image, then the target noise.
            t = torch.randint(0, T, (images.size(0),)).to(device)
            noise = torch.randn_like(images)
            noisy_images = q_sample(images, t, noise)
            # The network is trained to recover the noise that was mixed in.
            predicted_noise = model(noisy_images, t)
            loss = F.mse_loss(predicted_noise, noise)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            batch_count += 1
    if batch_count == 0:
        print(f"Epoch {epoch+1}/{num_epochs}: No valid batches")
    else:
        avg_loss = total_loss / batch_count
        print(f"Epoch {epoch+1}/{num_epochs}, Avg Loss: {avg_loss:.4f}")

NameError: name 'model' is not defined