In [71]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import os
import glob
import sys
from PIL import Image
from torch.utils.data import Dataset, DataLoader

import torchvision.transforms as transforms
import torchvision.utils as vutils
import matplotlib.pyplot as plt


In [72]:
# Use the GPU when one is available; otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Folder holding the training images. The original machine-specific location is
# kept as the default so existing runs behave the same, but it can now be
# overridden without editing the notebook by setting the DATA_FOLDER
# environment variable before the kernel starts.
DATA_FOLDER = os.environ.get(
    'DATA_FOLDER',
    '/Users/fizzausman/Desktop/BSRGAN-finnnn/results',
)
IMG_SIZE = 64            # images are resized to IMG_SIZE x IMG_SIZE
BATCH_SIZE = 16
NUM_WORKERS = 0          # DataLoader worker processes (0 = load in the main process)
DIFFUSION_STEPS = 1000   # number of noising steps in the diffusion schedule

EPOCHS = 100
MAX_BATCHES = 200        # upper bound on batches processed per epoch
LR = 2e-4                # Adam learning rate

# Sampling output: how many images to generate per epoch and where to save them.
SAMPLE_BATCH = 8
OUT_DIR = './samples'
os.makedirs(OUT_DIR, exist_ok=True)

# Dataset that lists the image files in a folder and loads them on demand as RGB (implements __len__ and __getitem__)
class ReadingImages(Dataset):
    """Map-style dataset over the ``*.png`` and ``*.jpg`` files directly inside a folder.

    The file list is built once at construction time and sorted by path.
    Images are opened lazily, one per ``__getitem__`` call.

    Args:
        folder: Directory that is globbed for ``*.png`` and ``*.jpg`` files.
        transform: Optional callable applied to the RGB PIL image; when it is
            falsy the PIL image itself is returned.
    """

    def __init__(self, folder, transform=None):
        matches = []
        for pattern in ("*.png", "*.jpg"):
            matches.extend(glob.glob(os.path.join(folder, pattern)))
        matches.sort()
        self.files = matches
        self.transform = transform

    def __len__(self):
        """Number of image files found in the folder."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load the ``idx``-th file as RGB and run it through the transform, if any."""
        image = Image.open(self.files[idx]).convert("RGB")
        if not self.transform:
            return image
        return self.transform(image)

# Preprocessing + light augmentation: resize, random horizontal flip, colour
# jitter, then convert to a tensor and normalise each channel with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),  # hair must not always be on left side
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.08),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

dataset = ReadingImages(DATA_FOLDER, transform)

# With drop_last=True a dataset smaller than one batch yields no batches at all,
# so training would silently do nothing. Fail early with an actionable message.
if len(dataset) < BATCH_SIZE:
    raise ValueError(
        f"Found {len(dataset)} image(s) in {DATA_FOLDER!r}, but at least "
        f"BATCH_SIZE={BATCH_SIZE} are required (drop_last=True discards incomplete batches)."
    )

loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    drop_last=True,
)




In [69]:
# Diffusion schedule, built once on DEVICE.
# betas: amount of noise added at each step (linear from 1e-4 to 0.02).
betas = torch.linspace(1e-4, 0.02, DIFFUSION_STEPS, device=DEVICE)
# alphas: how much of the image remains after a single step.
alphas = 1.0 - betas
# cum_alpha_bar: running product of alphas, i.e. how much image is left after many steps.
cum_alpha_bar = alphas.cumprod(dim=0)
# alpha_bar_prev: the same sequence shifted right by one step, with 1.0 in front.
alpha_bar_prev = torch.cat((cum_alpha_bar.new_ones(1), cum_alpha_bar[:-1]))

#time embedding
class SinCosPosEmb(nn.Module):
    """Sinusoidal timestep embedding followed by a small MLP.

    Each timestep ``t`` is mapped to ``[sin(t * f_k), cos(t * f_k)]`` over
    ``dim // 2`` frequencies ``f_k = max_period ** (-k / (dim // 2))``, padded
    with one zero column when ``dim`` is odd, and then passed through
    ``post_mlp`` (Linear -> SiLU -> Linear).

    Args:
        dim: Size of the embedding that is returned.
        max_period: Controls the lowest frequency used. The default keeps the
            constant this class has always used.
    """

    def __init__(self, dim, max_period=20000):
        super().__init__()
        self.dim = dim
        self.max_period = max_period

        self.post_mlp = nn.Sequential(
            nn.Linear(dim, dim * 2),
            nn.SiLU(),
            nn.Linear(dim * 2, dim)
        )

    def making_embedding(self, t):
        """Return the embedding of a 1-D tensor of timesteps, shaped (len(t), dim)."""
        half = self.dim // 2
        # Use the raw step index. Squashing t into [0, 1] first (as this code used
        # to do) keeps every sin/cos argument below 1 radian because the largest
        # frequency is 1, so the features barely vary between timesteps.
        t = t.float()
        exponent = torch.arange(half, device=t.device).float() / half
        freqs = torch.exp(-math.log(self.max_period) * exponent)
        args = t[:, None] * freqs[None, :]
        embedding = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
        if self.dim % 2 == 1:
            # sin/cos only fill 2 * half columns; pad so the width equals dim.
            embedding = torch.cat([embedding, torch.zeros(t.size(0), 1, device=t.device)], dim=-1)

        embedding = self.post_mlp(embedding)
        return embedding

    def forward(self, t):
        return self.making_embedding(t)

In [73]:
#refines image features 
class ResidualBlock(nn.Module):
    """Two Conv -> GroupNorm -> SiLU stages with a residual connection.

    An optional time embedding is projected to ``out_ch`` channels and added to
    the feature map between the two stages. The residual branch is a 1x1
    convolution when the channel count changes and the identity otherwise.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        time_emb_dim: Width of the time embedding; when falsy the block has no
            time projection and ignores any ``t_emb`` passed to ``forward``.
    """

    def __init__(self, in_ch, out_ch, time_emb_dim=None):
        super().__init__()
        groups = self._num_groups(out_ch)

        self.path1 = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.GroupNorm(groups, out_ch),
            nn.SiLU()
        )

        self.path2 = nn.Sequential(
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.GroupNorm(groups, out_ch),
            nn.SiLU()
        )
        self.time_proj = nn.Linear(time_emb_dim, out_ch) if time_emb_dim else None
        self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    @staticmethod
    def _num_groups(channels, max_groups=8):
        """Pick a GroupNorm group count that always divides ``channels``.

        Fewer than ``max_groups`` channels use a single group and multiples of
        ``max_groups`` use ``max_groups`` groups, exactly as before. Other
        channel counts (e.g. 12), which previously made GroupNorm raise, fall
        back to the greatest common divisor.
        """
        if channels < max_groups:
            return 1
        return math.gcd(max_groups, channels)

    def forward(self, x, t_emb=None):
        h = self.path1(x)

        if t_emb is not None and self.time_proj is not None:
            # Broadcast the per-sample projection over the two spatial dimensions.
            h = h + self.time_proj(t_emb)[:, :, None, None]

        h = self.path2(h)

        return h + self.skip(x)

In [74]:
class Downsample(nn.Module):
    """Learned downsampling: a stride-2 convolution (kernel 4, padding 1).

    The channel count is unchanged; an even height/width ``H`` becomes ``H / 2``.
    """

    def __init__(self, ch):
        super().__init__()
        self.op = nn.Conv2d(
            in_channels=ch,
            out_channels=ch,
            kernel_size=4,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        downsampled = self.op(x)
        return downsampled

class Upsample(nn.Module):
    """Double the spatial size by nearest-neighbour interpolation, then apply a 3x3 convolution.

    The channel count is unchanged.
    """

    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=ch, out_channels=ch, kernel_size=3, padding=1)

    def forward(self, x):
        enlarged = F.interpolate(x, scale_factor=2.0, mode='nearest')
        smoothed = self.conv(enlarged)
        return smoothed

class UNetSmall(nn.Module):
    """Small time-conditioned U-Net.

    Structure, in order: an initial 3x3 convolution; one ``ResidualBlock`` +
    ``Downsample`` pair per entry of ``channel_mults``; two middle residual
    blocks (widening to twice the deepest channel count and back); one
    ``Upsample`` + ``ResidualBlock`` pair per level, each consuming the matching
    encoder output as a skip connection; and a GroupNorm -> SiLU -> 3x3
    convolution head that maps back to ``in_ch`` channels.

    Args:
        in_ch: Channels of the input and of the output.
        base_ch: Channel count after the initial convolution.
        time_emb_dim: Width of the time embedding given to every residual block.
        channel_mults: Per-level multipliers applied to ``base_ch``.
    """

    def __init__(self, in_ch=3, base_ch=64, time_emb_dim=128, channel_mults=(1,2,4)):
        super().__init__()
        self.time_mlp = nn.Sequential(
            SinCosPosEmb(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        )
        self.init_conv = nn.Conv2d(in_ch, base_ch, 3, padding=1)
        chs = [base_ch * m for m in channel_mults]

        # Encoder: one residual block and one downsampling step per level.
        self.enc_blocks = nn.ModuleList()
        self.downs = nn.ModuleList()
        in_c = base_ch
        for out_c in chs:
            self.enc_blocks.append(ResidualBlock(in_c, out_c, time_emb_dim=time_emb_dim))
            self.downs.append(Downsample(out_c))
            in_c = out_c

        # Bottleneck: widen to 2x the deepest channel count, then come back.
        self.mid1 = ResidualBlock(in_c, in_c*2, time_emb_dim=time_emb_dim)
        self.mid2 = ResidualBlock(in_c*2, in_c, time_emb_dim=time_emb_dim)

        # Decoder: each block receives the upsampled features concatenated with
        # the encoder output of the same level, hence in_c + out_c input channels.
        self.ups = nn.ModuleList()
        self.dec_blocks = nn.ModuleList()
        for out_c in reversed(chs):
            self.ups.append(Upsample(in_c))
            self.dec_blocks.append(ResidualBlock(in_c + out_c, out_c, time_emb_dim=time_emb_dim))
            in_c = out_c

        # GroupNorm needs a group count that divides the channel count. Multiples
        # of 8 still get 8 groups and fewer than 8 channels still get 1; other
        # widths use the greatest common divisor instead of raising.
        final_groups = math.gcd(8, in_c) if in_c >= 8 else 1
        self.final_conv = nn.Sequential(
            nn.GroupNorm(num_groups=final_groups, num_channels=in_c),
            nn.SiLU(),
            nn.Conv2d(in_c, in_ch, 3, padding=1)
        )

    def forward(self, x, t):
        """Run the network on a batch ``x`` with one timestep per sample in ``t``.

        The output has ``in_ch`` channels and the same height and width as ``x``.
        """
        t_emb = self.time_mlp(t)
        h = self.init_conv(x)
        skips = []
        for enc, down in zip(self.enc_blocks, self.downs):
            h = enc(h, t_emb)
            skips.append(h)
            h = down(h)
        h = self.mid1(h, t_emb)
        h = self.mid2(h, t_emb)
        for up, dec, skip in zip(self.ups, self.dec_blocks, reversed(skips)):
            h = up(h)
            if h.shape[-2:] != skip.shape[-2:]:
                # Odd sizes lose a row/column when downsampled, so doubling does
                # not land back on the skip's size. Resize the upsampled features
                # to the skip (rather than cropping the skip) so the final output
                # keeps the input resolution.
                h = F.interpolate(h, size=skip.shape[-2:], mode='nearest')
            h = torch.cat([h, skip], dim=1)
            h = dec(h, t_emb)
        out = self.final_conv(h)
        return out


# Build the denoising network on DEVICE, together with its Adam optimizer and
# the mean-squared-error loss used to compare predicted and true noise.
model = UNetSmall(in_ch=3, base_ch=64, time_emb_dim=128)
model = model.to(DEVICE)
opt = torch.optim.Adam(params=model.parameters(), lr=LR)
mse = nn.MSELoss(reduction='mean')


def forward_process(x0, t, alpha_bar):
    """Noise a batch of clean images to the given timesteps.

    Computes ``sqrt(alpha_bar[t]) * x0 + sqrt(1 - alpha_bar[t]) * noise`` with
    freshly drawn standard-normal ``noise`` of the same shape as ``x0``.

    Args:
        x0: Batch of clean images; the per-sample scales are reshaped to
            ``(-1, 1, 1, 1)``, so a 4-D batch is expected.
        t: Index tensor with one timestep per sample, used to index ``alpha_bar``.
        alpha_bar: 1-D tensor of cumulative alpha products.

    Returns:
        ``(xt, noise)``: the noised batch and the noise that was added.
    """
    noise = torch.randn_like(x0)
    remaining = alpha_bar[t]
    signal_scale = remaining.sqrt().view(-1, 1, 1, 1)
    noise_scale = (1.0 - remaining).sqrt().view(-1, 1, 1, 1)
    return signal_scale * x0 + noise_scale * noise, noise


@torch.no_grad()
def sample(model, shape, betas, alphas, alpha_bar, alpha_bar_prev, device):
    """Generate images by running the reverse diffusion loop from pure noise.

    Starts from standard-normal noise of the requested ``shape`` and walks the
    schedule from its last step down to step 0. At every step the model predicts
    the noise, a clean-image estimate is derived from it, and the next sample is
    drawn around the resulting mean (no noise is added on the final step).

    Args:
        model: Callable as ``model(x, timesteps)`` returning predicted noise
            shaped like ``x``; must provide ``training``, ``eval()`` and ``train()``.
        shape: Shape of the batch to generate; ``shape[0]`` is the batch size.
        betas, alphas, alpha_bar, alpha_bar_prev: 1-D schedule tensors of equal
            length, indexed by step.
        device: Device on which the initial noise and timestep tensor are created.

    Returns:
        The generated batch, clamped to ``[-1, 1]``.
    """
    # Take the step count from the schedule that was passed in rather than from a
    # module-level constant, so the loop can never index past the given tensors.
    num_steps = betas.shape[0]

    # Sample in eval mode, but give the model back in whatever mode the caller
    # had it (previously it was left in eval mode for the rest of training).
    was_training = model.training
    model.eval()
    try:
        x = torch.randn(shape, device=device)

        for step in reversed(range(num_steps)):
            t = step
            ts = torch.full((shape[0],), t, device=device, dtype=torch.long)
            eps_pred = model(x, ts)

            a_t = alphas[t]
            ab_t = alpha_bar[t]
            ab_prev = alpha_bar_prev[t]
            b_t = betas[t]

            # Invert x = sqrt(ab_t) * x0 + sqrt(1 - ab_t) * eps for x0; the small
            # epsilon guards the division.
            x0_pred = (x - torch.sqrt(1.0 - ab_t) * eps_pred) / (torch.sqrt(ab_t) + 1e-8)

            # Mean of the next sample as a blend of the clean estimate and the current x.
            coef1 = torch.sqrt(ab_prev) * b_t / (1.0 - ab_t)
            coef2 = torch.sqrt(a_t) * (1.0 - ab_prev) / (1.0 - ab_t)
            mean = coef1 * x0_pred + coef2 * x

            if step > 0:
                x = mean + torch.sqrt(b_t) * torch.randn_like(x)
            else:
                x = mean
    finally:
        model.train(was_training)

    return x.clamp(-1, 1)

# Make sure the schedule tensors used below live on the training device.
alpha_bar = cum_alpha_bar.to(DEVICE)
alpha_bar_prev = alpha_bar_prev.to(DEVICE)

print("Starting training on this device", DEVICE)
for epoch in range(EPOCHS):
    # Put the network in training mode at the start of every epoch, regardless
    # of the mode the previous epoch's sampling left it in.
    model.train()

    for batch_idx, x0 in enumerate(loader):
        # Cap the work done per epoch.
        if batch_idx >= MAX_BATCHES:
            break

        x0 = x0.to(DEVICE)
        # One random timestep per image in the batch.
        t = torch.randint(0, DIFFUSION_STEPS, (x0.size(0),), device=DEVICE, dtype=torch.long)

        # Noise the clean batch to those timesteps; reuse the shared helper
        # instead of repeating its formula here.
        xt, noise = forward_process(x0, t, alpha_bar)

        # The network is trained to predict the noise that was added.
        noise_pred = model(xt, t)

        loss = mse(noise_pred, noise)

        opt.zero_grad()
        loss.backward()
        opt.step()

        if batch_idx % 10 == 0:
            print(f"Epoch {epoch:03d} Batch {batch_idx:04d}  Loss: {loss.item():.6f}")

    # End of epoch: generate a batch of images to monitor progress.
    with torch.no_grad():
        samples = sample(
            model,
            shape=(SAMPLE_BATCH, 3, IMG_SIZE, IMG_SIZE),
            betas=betas,
            alphas=alphas,
            alpha_bar=alpha_bar,
            alpha_bar_prev=alpha_bar_prev,
            device=DEVICE
        )

    # Map from [-1, 1] back to [0, 1] before writing the image grid.
    samples_01 = (samples * 0.5) + 0.5
    grid = vutils.make_grid(samples_01, nrow=4, padding=2, normalize=False)
    out_path = os.path.join(OUT_DIR, f'sample_epoch_{epoch:03d}.png')
    vutils.save_image(grid, out_path)
    print(f"Saved sample grid to {out_path}")

print("Training finished.")


Starting training on this device cpu
Epoch 000 Batch 0000  Loss: 1.205566
Epoch 000 Batch 0010  Loss: 0.554684
Epoch 000 Batch 0020  Loss: 0.209574
Epoch 000 Batch 0030  Loss: 0.160273
Epoch 000 Batch 0040  Loss: 0.191099
Epoch 000 Batch 0050  Loss: 0.123958
Epoch 000 Batch 0060  Loss: 0.120723
Epoch 000 Batch 0070  Loss: 0.138267
Epoch 000 Batch 0080  Loss: 0.075937
Epoch 000 Batch 0090  Loss: 0.084412
Epoch 000 Batch 0100  Loss: 0.107960
Epoch 000 Batch 0110  Loss: 0.096695
Epoch 000 Batch 0120  Loss: 0.090893
Epoch 000 Batch 0130  Loss: 0.143969
Epoch 000 Batch 0140  Loss: 0.100229
Epoch 000 Batch 0150  Loss: 0.127983
Epoch 000 Batch 0160  Loss: 0.063832
Epoch 000 Batch 0170  Loss: 0.121347
Epoch 000 Batch 0180  Loss: 0.115272
Epoch 000 Batch 0190  Loss: 0.175867
Saved sample grid to ./samples/sample_epoch_000.png
Epoch 001 Batch 0000  Loss: 0.055919
Epoch 001 Batch 0010  Loss: 0.139013
Epoch 001 Batch 0020  Loss: 0.085872
Epoch 001 Batch 0030  Loss: 0.088676
Epoch 001 Batch 0040  L