In [22]:
import math
import os.path

import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.amp import autocast, GradScaler
import mlflow
# Experiment tracking: honour an MLFLOW_TRACKING_URI set in the environment,
# otherwise fall back to the local tracking server used so far.
mlflow.set_tracking_uri(os.environ.get('MLFLOW_TRACKING_URI', 'http://localhost:5000'))
# 'medium' allows float32 matmuls to use faster lower-precision internal kernels.
torch.set_float32_matmul_precision('medium')
# Kept as the last expression so the notebook still displays the Experiment object.
mlflow.set_experiment("image-gen")

<Experiment: artifact_location='mlflow-artifacts:/527894522082058475', creation_time=1759781409349, experiment_id='527894522082058475', last_update_time=1759781409349, lifecycle_stage='active', name='image-gen', tags={'mlflow.experimentKind': 'custom_model_development'}>

In [23]:
# Run configuration. Everything collected in `params` is logged to MLflow with the run.
image_dim = 128          # images are resized to image_dim x image_dim
device = 'cuda'
num_time_channels = 13   # number of Fourier time-feature channels appended to the image
batch_size = 8
num_epochs = 10
show_every = 1000        # roughly how many training images between generated samples
do_load = True           # load ./flow_net.pth before training if it exists
seed = 0                 # seeds torch's RNGs so runs are reproducible
torch.manual_seed(seed)
params = {
    'image_dim': image_dim,
    'num_time_channels': num_time_channels,
    'batch_size': batch_size,
    'num_epochs': num_epochs,
    'device': device,
    'do_load': do_load,
    'show_every': show_every,
    'seed': seed,
}

In [24]:
# Resize every image to a square of side image_dim, convert to a tensor and
# shift/scale each channel with mean 0.5 / std 0.5 (ToTensor's [0, 1] -> [-1, 1]).
flower_transform = transforms.Compose([
    transforms.Resize((image_dim, image_dim)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# NOTE(review): no `split` is passed, so torchvision's default split is used;
# confirm that is the intended subset of Flowers102.
dset = torchvision.datasets.Flowers102(
    root='data/flowers-102',
    transform=flower_transform,
    download=True,
)

dldr = DataLoader(
    dset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

# Shape of one batch of images (labels are ignored by this notebook).
next(iter(dldr))[0].shape

torch.Size([8, 3, 128, 128])

In [25]:
def fourier_time_features(batch_n, shape, time_samples, num_channels=None):
    """Broadcast a sinusoidal embedding of the time values over a spatial grid.

    Channel ``k`` divides the time by ``10_000 ** (k / num_channels)``; even
    channels hold the cosine of that angle and odd channels the sine. Every
    sample's embedding is constant over the spatial dimensions, so the result
    can be concatenated with an image batch along the channel axis.

    Args:
        batch_n: batch size of the returned tensor.
        shape: spatial shape to broadcast over, e.g. ``(H, W)``.
        time_samples: tensor holding ``batch_n`` time values, or a single value
            that is shared by the whole batch.
        num_channels: number of embedding channels; defaults to the notebook's
            ``num_time_channels``.

    Returns:
        float32 tensor of shape ``(batch_n, num_channels, *shape)`` on the same
        device as ``time_samples``.

    Raises:
        RuntimeError: if ``time_samples`` holds neither 1 nor ``batch_n`` values.
    """
    if num_channels is None:
        num_channels = num_time_channels
    dev = time_samples.device

    channel_idx = torch.arange(num_channels, device=dev)
    # 10_000 ** (k / C), built directly on the input's device.
    factor = torch.exp(math.log(10_000) * channel_idx.float() / num_channels)
    angles = time_samples.reshape(-1, 1).float() / factor  # (batch_n or 1, C)

    # Even channels carry cos, odd channels carry sin.
    per_sample = torch.where(channel_idx % 2 == 0, torch.cos(angles), torch.sin(angles))
    # A single time value is shared across the batch; otherwise sizes must match.
    per_sample = per_sample.expand(batch_n, num_channels)

    singleton_spatial = (1,) * len(shape)
    return (per_sample
            .reshape(batch_n, num_channels, *singleton_spatial)
            .expand(batch_n, num_channels, *shape)
            .contiguous())

# Sanity check: one time value per sample is broadcast to a full-resolution feature map.
time_feature_shape = fourier_time_features(batch_size, (image_dim, image_dim),
                                           torch.ones(batch_size, device=device)).shape
assert time_feature_shape == (batch_size, num_time_channels, image_dim, image_dim)
time_feature_shape

torch.Size([8, 13, 128, 128])

In [26]:
def time_coefficients(time_values):
    """Weights ``(a_t, b_t) = (1 - t, t)`` for the interpolation ``a_t * x_0 + b_t * x_1``.

    Time values are clipped to ``[0, 1]`` first. Both weights are reshaped to
    ``(-1, 1, 1, 1)`` so they broadcast against ``(N, C, H, W)`` image batches.
    """
    t = torch.clamp(time_values, min=0, max=1).view(-1, 1, 1, 1)
    return 1 - t, t

In [27]:
class ResidualLinear(nn.Module):
    """Fully connected residual block applied to a whole per-sample feature map.

    Each sample is flattened to ``dim`` features and passed through
    ``Linear(dim, hidden_dim) -> LeakyReLU -> Linear(hidden_dim, dim)``. The
    result is reshaped back to the input shape and added to the input
    (``x + f(x)``).

    Args:
        dim: number of features per sample, i.e. the product of all non-batch
            dimensions of the input.
        hidden_dim: width of the hidden layer (default 1024, the previously
            hard-coded value, so existing checkpoints still load).
    """

    def __init__(self, dim, hidden_dim=1024):
        super().__init__()
        self.linear = nn.Sequential(nn.Linear(dim, hidden_dim),
                                    nn.LeakyReLU(),
                                    nn.Linear(hidden_dim, dim))

    def forward(self, x):
        """Return ``x + f(x)``; accepts non-contiguous inputs."""
        # reshape (not view) so non-contiguous feature maps (e.g. permuted or
        # channels_last tensors) can be flattened too.
        out = self.linear(x.reshape(x.shape[0], -1))
        return x + out.view(x.shape)

class UNet(nn.Module):
    """Recursive down/up-sampling network.

    One level of the recursion:
      * ``conv_down``: stride-2 3x3 conv ``channels -> 2*channels``, BatchNorm,
        ReLU, then adaptive max-pooling to ``image_dims // 2`` per side.
      * ``sub_net``: another ``UNet`` one level shallower on the downsampled
        features, or a ``ResidualLinear`` bottleneck over the flattened
        ``(2*channels, image_dims//2, image_dims//2)`` map once ``depth`` reaches 1.
      * ``conv_up``: the downsampled features and ``sub_net``'s output are
        concatenated along channels (``4*channels``), then a stride-2 3x3
        transposed conv maps back to ``channels``, followed by BatchNorm, ReLU and
        adaptive max-pooling to ``image_dims`` per side.

    Output shape is ``(N, channels, image_dims, image_dims)``. Inputs are
    presumably square ``(N, channels, image_dims, image_dims)`` maps (the
    bottleneck's width is derived from ``image_dims``) - confirm at call sites.

    NOTE(review): for an even side ``h`` the transposed conv yields ``2*h - 1``
    pixels; the trailing AdaptiveMaxPool2d stretches that to ``image_dims``.
    """

    def __init__(self, depth, channels, image_dims):
        super().__init__()
        wide = 2 * channels
        half = image_dims // 2

        # Submodules are created in the same order as before so parameter
        # initialisation and state_dict keys are unchanged.
        self.conv_down = nn.Sequential(
            nn.Conv2d(channels, wide, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(wide),
            nn.ReLU(),
            nn.AdaptiveMaxPool2d((half, half)),
        )
        self.conv_up = nn.Sequential(
            nn.ConvTranspose2d(2 * wide, channels, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.AdaptiveMaxPool2d((image_dims, image_dims)),
        )
        self.sub_net = (UNet(depth - 1, wide, half) if depth > 1
                        else ResidualLinear(wide * half ** 2))

    def forward(self, x):
        down = self.conv_down(x)
        merged = torch.cat((down, self.sub_net(down)), dim=1)
        return self.conv_up(merged)

In [28]:
# Velocity model: normalise the (RGB + time-feature) input, run the recursive
# UNet, then project back to 3 channels with a 1x1 conv.
unet_depth = 6
in_channels = 3 + num_time_channels
flow_net = nn.Sequential(
    nn.BatchNorm2d(in_channels),
    UNet(unet_depth, in_channels, image_dim),
    nn.Conv2d(in_channels, 3, 1),
).to(device)
# flow_net.compile()  # optional torch.compile speed-up; disabled for this run

checkpoint_path = './flow_net.pth'
if do_load and os.path.exists(checkpoint_path):
    # weights_only=True refuses to unpickle arbitrary objects from the file;
    # map_location puts the tensors on the configured device directly.
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
    flow_net.load_state_dict(state_dict)
flow_net

Sequential(
  (0): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (1): UNet(
    (conv_down): Sequential(
      (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): AdaptiveMaxPool2d(output_size=(64, 64))
    )
    (conv_up): Sequential(
      (0): ConvTranspose2d(64, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): AdaptiveMaxPool2d(output_size=(128, 128))
    )
    (sub_net): UNet(
      (conv_down): Sequential(
        (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): AdaptiveMaxPool2d(output_size=(32, 32))
      )
      (conv_up): Sequential(
 

In [29]:
# Adam (fused kernel) over all model parameters, with the loss scaler used for
# mixed-precision backward passes and a Smooth-L1 regression objective.
optim = torch.optim.Adam(flow_net.parameters(), lr=1e-3, fused=True)
scaler = GradScaler()
loss_fn = nn.SmoothL1Loss()

In [30]:
def sample_pic(step, n_steps=10):
    """Generate one image with ``flow_net`` and log it to MLflow.

    Starts from standard-normal noise at t=0 and takes ``n_steps`` steps of
    size ``1/n_steps`` up to t=1. Each step applies

        x += (2*v - x/t') * dt + sqrt(2 * (1 - t') / t' * dt) * noise

    where ``v`` is the network's velocity at ``t = j/n_steps``. The drift and
    noise scale use the next time ``t' = (j+1)/n_steps``, which avoids dividing
    by zero at t=0 and makes the final step (t'=1) noise-free.

    The model's previous train/eval mode is restored afterwards, even if
    sampling raises.

    Args:
        step: training step, used only to build the MLflow image key.
        n_steps: number of integration steps (default 10).

    Raises:
        ValueError: if ``n_steps`` is not positive.
    """
    if n_steps < 1:
        raise ValueError(f"n_steps must be >= 1, got {n_steps}")

    was_training = flow_net.training
    flow_net.eval()
    try:
        x_t = torch.randn((1, 3, image_dim, image_dim), device=device)
        with torch.no_grad():
            for j in range(n_steps):
                time_values = torch.full((1,), j / n_steps, device=device)
                fourier_features = fourier_time_features(1, (image_dim, image_dim), time_values)
                velocity_values = flow_net(torch.cat([x_t, fourier_features], dim=1))

                next_time = (j + 1) / n_steps
                var = (1 - next_time) / next_time
                x_t += (2 * velocity_values - x_t / next_time) / n_steps + \
                    math.sqrt(2 * var / n_steps) * torch.randn_like(velocity_values)
    finally:
        flow_net.train(was_training)

    # Map [-1, 1] model space to a [0, 1] HWC float image for logging.
    image = (x_t[0].permute(1, 2, 0).cpu().clamp(-1, 1).numpy() + 1) / 2
    mlflow.log_image(image, key=f"generated-image-{step:05d}")

In [31]:
# Generate a preview roughly every `show_every` training images; never let the
# interval drop to 0 (which would raise ZeroDivisionError in the modulo below).
sample_every = max(1, show_every // batch_size)

with mlflow.start_run():
    mlflow.log_params(params)
    try:
        for epoch in range(num_epochs):
            print('Epoch {}/{}'.format(epoch+1, num_epochs))
            for i, batch in enumerate(dldr):
                step = len(dldr) * epoch + i
                x_1, _ = batch  # labels are not used
                x_1 = x_1.to(device, non_blocking=True)
                x_0 = torch.randn_like(x_1)

                # Straight-line path between noise (t=0) and data (t=1);
                # its velocity x_1 - x_0 is the regression target.
                times = torch.rand(x_1.shape[0], device=device)
                a, b = time_coefficients(times)
                x_t = a * x_0 + b * x_1
                target = x_1 - x_0

                fourier = fourier_time_features(x_t.shape[0], (image_dim, image_dim), times)
                x_in = torch.cat((x_t, fourier), dim=1)

                with autocast(device_type=device):
                    velocity = flow_net(x_in)
                    loss = loss_fn(velocity, target)

                optim.zero_grad()
                scaler.scale(loss).backward()
                scaler.step(optim)
                scaler.update()

                # MLflow expects plain numbers: detach, upcast the (possibly
                # fp16) prediction to fp32 and convert with .item().
                with torch.no_grad():
                    mlflow.log_metrics({
                        'scalars/loss': loss.item(),
                        'scalars/velocity_estimate': velocity.float().pow(2).mean().sqrt().item(),
                        'scalars/velocity_real': target.pow(2).mean().sqrt().item(),
                    }, step=step)

                if i % sample_every == sample_every - 1:
                    sample_pic(step)
    except KeyboardInterrupt:
        # Keep the partially trained weights when training is stopped by hand.
        torch.save(flow_net.state_dict(), 'flow_net.pth')
        raise

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
🏃 View run skillful-frog-547 at: http://localhost:5000/#/experiments/527894522082058475/runs/6ab4808c65b14347ad7261b8fda9138b
🧪 View experiment at: http://localhost:5000/#/experiments/527894522082058475


In [32]:
# Persist the trained weights to flow_net.pth in the working directory, the
# same file the model cell loads from when do_load is set.
torch.save(obj=flow_net.state_dict(), f='flow_net.pth')