In [None]:
from google.colab import drive
# Mount Google Drive under /content/drive so files stored in Drive can be read
# through that path from this runtime.
drive.mount('/content/drive')




Mounted at /content/drive


In [None]:
# Define the path to your zip file
zip_path = '/content/drive/My Drive/dataset_gc.zip'

# Define the destination folder where files will be unzipped
unzip_dest = '/content/dataset'

# Unzip the file
!unzip -q "$zip_path" -d "$unzip_dest"

# Check the extracted files
!ls "$unzip_dest"


dataset_gc  __MACOSX


In [None]:
# List one class folder to confirm that the extraction produced image files.
!ls '/content/dataset/dataset_gc/dataset/Europeans'



10_0_0_20170103233459275.jpg   26_0_0_20170105163435235.jpg  52_1_0_20170109221057047.jpg
10_0_0_20170110220235233.jpg   26_0_0_20170105183712607.jpg  52_1_0_20170110122416544.jpg
10_0_0_20170110220251986.jpg   26_0_0_20170108235818665.jpg  52_1_0_20170110122705423.jpg
10_0_0_20170110220541850.jpg   26_1_0_20170103180649464.jpg  52_1_0_20170110123807882.jpg
10_0_0_20170110220548521.jpg   26_1_0_20170103181112840.jpg  52_1_0_20170110152849848.jpg
10_0_0_20170110224253445.jpg   26_1_0_20170103181710200.jpg  52_1_0_20170110153654447.jpg
10_0_0_20170110224255796.jpg   26_1_0_20170103181852617.jpg  53_0_0_20170104184207950.jpg
10_0_0_20170110224416035.jpg   26_1_0_20170104022424245.jpg  53_0_0_20170104184411830.jpg
10_0_0_20170110224524253.jpg   26_1_0_20170104165749289.jpg  53_0_0_20170104210640915.jpg
10_0_0_20170110224725285.jpg   26_1_0_20170105183935352.jpg  53_0_0_20170104212413221.jpg
10_0_0_20170110225035898.jpg   26_1_3_20161220222108978.jpg  53_0_0_20170109012709495.jpg
10_0_0_201

In [None]:
import os
import random
import logging
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader, ConcatDataset
from torchvision.datasets import ImageFolder
from fastprogress import progress_bar
import wandb
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
from torch.utils.checkpoint import checkpoint
from torch.cuda.amp import autocast, GradScaler

# Turn on cuDNN benchmarking; set_seed(..., reproducible=True) switches it off again.
torch.backends.cudnn.benchmark = True
# Permit TF32 for CUDA matmuls and for cuDNN operations.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Set environment variable for memory allocation
# NOTE(review): presumably this has to be set before the first CUDA allocation
# to have any effect — confirm. `os` is already imported at the top of the cell.
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'



# Utility Functions
def set_seed(s, reproducible=False):
    """Seed Python's, NumPy's and PyTorch's random number generators.

    Args:
        s: Integer seed. NumPy receives ``s % (2**32 - 1)`` so that large
            seeds still fit its accepted range.
        reproducible: When true, additionally force deterministic cuDNN
            kernels and disable cuDNN benchmarking.
    """
    random.seed(s)
    np.random.seed(s % (2**32 - 1))
    torch.manual_seed(s)
    torch.cuda.manual_seed_all(s)

    if not reproducible:
        return

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def mk_folders(run_name):
    """Ensure ``models/<run_name>`` and ``results/<run_name>`` exist.

    Both the top-level folders and the per-run sub-folders are created
    relative to the current working directory; existing folders are left
    untouched.
    """
    for root in ("models", "results"):
        os.makedirs(root, exist_ok=True)
        os.makedirs(os.path.join(root, run_name), exist_ok=True)

def plot_images(images):
    """Display a batch of images side by side in one row.

    Args:
        images: Tensor that is iterated along its first dimension; each item
            is permuted with ``(1, 2, 0)`` before plotting, so items are
            expected to be channel-first 3-D tensors.
    """
    import matplotlib.pyplot as plt
    plt.figure(figsize=(32, 32))
    # Concatenating along the last axis lays the images out left to right.
    strip = torch.cat(list(images.cpu()), dim=-1)
    canvas = strip.permute(1, 2, 0).cpu()
    plt.imshow(canvas)
    plt.show()

def center_crop_square(image):
    """Crop the largest centred square out of ``image``.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``crop((left, top, right, bottom))`` method (e.g. a PIL image).

    Returns:
        Whatever ``image.crop`` returns for the centred square box.
    """
    width, height = image.size
    side = min(width, height)
    x0 = (width - side) // 2
    y0 = (height - side) // 2
    return image.crop((x0, y0, x0 + side, y0 + side))

# Model Modules
def one_param(m):
    """Return the first parameter yielded by ``m.parameters()``."""
    params = iter(m.parameters())
    return next(params)

class EMA:
    """Keeps an exponential moving average of a model's parameters.

    ``beta`` is the weight given to the previous average; ``step`` counts how
    many times ``step_ema`` has been called.
    """

    def __init__(self, beta):
        super().__init__()
        self.beta = beta
        self.step = 0

    def update_model_average(self, ma_model, current_model):
        """Blend every parameter of ``current_model`` into ``ma_model`` in place."""
        pairs = zip(current_model.parameters(), ma_model.parameters())
        for live, averaged in pairs:
            averaged.data = self.update_average(averaged.data, live.data)

    def update_average(self, old, new):
        """Return ``beta * old + (1 - beta) * new``; ``new`` itself if ``old`` is None."""
        if old is not None:
            return old * self.beta + (1 - self.beta) * new
        return new

    def step_ema(self, ema_model, model, step_start_ema=2000):
        """Advance the average by one training step.

        During the first ``step_start_ema`` calls the averaged model is simply
        overwritten with the live weights; afterwards it is updated with the
        moving-average rule.
        """
        warming_up = self.step < step_start_ema
        if warming_up:
            self.reset_parameters(ema_model, model)
        else:
            self.update_model_average(ema_model, model)
        self.step += 1

    def reset_parameters(self, ema_model, model):
        """Copy the full state dict of ``model`` into ``ema_model``."""
        ema_model.load_state_dict(model.state_dict())

class SelfAttention(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    The ``(B, C, H, W)`` input is flattened to ``H * W`` tokens of width ``C``,
    passed through pre-norm 4-head attention and a feed-forward block (both
    with residual connections) and reshaped back to ``(B, C, H, W)``.

    Args:
        channels: Number of feature channels ``C`` (the token width).
    """

    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        self.ff_self = nn.Sequential(
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x):
        """Apply self-attention to ``x`` and return a tensor of the same shape.

        Height and width are read separately, so non-square feature maps are
        supported; ``reshape`` (rather than ``view``) also accepts inputs that
        are not contiguous in memory.
        """
        height, width = x.shape[-2:]
        tokens = x.reshape(-1, self.channels, height * width).swapaxes(1, 2)
        normed = self.ln(tokens)
        attended, _ = self.mha(normed, normed, normed)
        attended = attended + tokens
        attended = self.ff_self(attended) + attended
        return attended.swapaxes(2, 1).reshape(-1, self.channels, height, width)

class DoubleConv(nn.Module):
    """Two 3x3 convolutions, each followed by a single-group GroupNorm.

    A GELU sits between the two conv/norm pairs. With ``residual=True`` the
    input is added to the result and a final GELU is applied.

    Args:
        in_channels: Channels of the input.
        out_channels: Channels of the output.
        mid_channels: Channels between the two convolutions; falls back to
            ``out_channels`` when falsy.
        residual: Whether to use the residual form described above.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):
        super().__init__()
        self.residual = residual
        hidden = mid_channels or out_channels
        stages = [
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, hidden),
            nn.GELU(),
            nn.Conv2d(hidden, out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, out_channels),
        ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run the conv stack, optionally wrapped in the residual connection."""
        out = self.double_conv(x)
        return F.gelu(x + out) if self.residual else out

class Down(nn.Module):
    """Encoder stage: 2x max-pool, two ``DoubleConv`` blocks, plus an embedding bias.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the produced feature map.
        emb_dim: Width of the conditioning embedding ``t``.
    """

    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()
        stages = [
            nn.MaxPool2d(2),
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels),
        ]
        self.maxpool_conv = nn.Sequential(*stages)
        self.emb_layer = nn.Sequential(nn.SiLU(), nn.Linear(emb_dim, out_channels))

    def forward(self, x, t):
        """Downsample ``x`` and add the projected embedding at every position."""
        features = self.maxpool_conv(x)
        rows, cols = features.shape[-2], features.shape[-1]
        # Project t to one value per channel, then tile it over the spatial grid.
        bias = self.emb_layer(t)[:, :, None, None].repeat(1, 1, rows, cols)
        return features + bias

class Up(nn.Module):
    """Decoder stage: 2x bilinear upsample, skip concat, two ``DoubleConv`` blocks.

    Args:
        in_channels: Channels after concatenating the skip tensor with the
            upsampled input.
        out_channels: Channels of the produced feature map.
        emb_dim: Width of the conditioning embedding ``t``.
    """

    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        stages = [
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels, in_channels // 2),
        ]
        self.conv = nn.Sequential(*stages)
        self.emb_layer = nn.Sequential(nn.SiLU(), nn.Linear(emb_dim, out_channels))

    def forward(self, x, skip_x, t):
        """Upsample ``x``, fuse it with ``skip_x`` and add the embedding bias."""
        upsampled = self.up(x)
        # The skip tensor goes first along the channel axis.
        fused = self.conv(torch.cat([skip_x, upsampled], dim=1))
        rows, cols = fused.shape[-2], fused.shape[-1]
        bias = self.emb_layer(t)[:, :, None, None].repeat(1, 1, rows, cols)
        return fused + bias

class UNet_conditional(nn.Module):
    """U-Net noise predictor conditioned on a timestep and an optional class label.

    Three ``Down`` stages and three ``Up`` stages, each followed by
    ``SelfAttention``, surround a three-block bottleneck. The timestep is
    turned into a sinusoidal embedding of width ``time_dim``; when a label is
    supplied its learned embedding is added to it.

    Args:
        c_in: Channels of the input image.
        c_out: Channels of the predicted noise.
        time_dim: Width of the timestep / label embedding.
        num_classes: Number of classes for the label embedding; ``None``
            builds an unconditional model without ``label_emb``.
        device: Stored on the instance for callers that read ``self.device``.
    """

    def __init__(self, c_in=3, c_out=3, time_dim=256, num_classes=None, device="cuda"):
        super().__init__()
        self.device = device
        self.time_dim = time_dim
        self.inc = DoubleConv(c_in, 64)
        self.down1 = Down(64, 128)
        self.sa1 = SelfAttention(128)
        self.down2 = Down(128, 256)
        self.sa2 = SelfAttention(256)
        self.down3 = Down(256, 256)
        self.sa3 = SelfAttention(256)

        self.bot1 = DoubleConv(256, 512)
        self.bot2 = DoubleConv(512, 512)
        self.bot3 = DoubleConv(512, 256)

        self.up1 = Up(512, 128)
        self.sa4 = SelfAttention(128)
        self.up2 = Up(256, 64)
        self.sa5 = SelfAttention(64)
        self.up3 = Up(128, 64)
        self.sa6 = SelfAttention(64)
        self.outc = nn.Conv2d(64, c_out, kernel_size=1)
        self.dropout = nn.Dropout(0.1)

        if num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_dim)

        # Enable gradient checkpointing (only takes effect in training mode)
        self.use_checkpoint = True

    def pos_encoding(self, t, channels):
        """Return the sinusoidal embedding of ``t``.

        Args:
            t: Float tensor of shape ``(batch, 1)`` holding the timesteps.
            channels: Width of the embedding; the first half holds sines and
                the second half cosines.

        The frequencies are created on ``t``'s own device so the encoding
        works wherever the input lives, regardless of the ``device`` string
        given to the constructor.
        """
        inv_freq = 1.0 / (
            10000 ** (torch.arange(0, channels, 2, device=t.device).float() / channels)
        )
        angles = t.repeat(1, channels // 2) * inv_freq
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    def _run(self, module, *inputs):
        """Call ``module``, through gradient checkpointing when it is active.

        ``use_reentrant=False`` is passed explicitly: it is the variant
        recommended by PyTorch and does not depend on any input requiring
        grad for the gradients of ``module``'s parameters to be computed.
        """
        if self.use_checkpoint and self.training:
            return checkpoint(module, *inputs, use_reentrant=False)
        return module(*inputs)

    def forward(self, x, t, y=None):
        """Predict the noise contained in ``x``.

        Args:
            x: Noisy images, ``(batch, c_in, H, W)``.
            t: Integer timesteps, ``(batch,)``.
            y: Optional class indices, ``(batch,)``; ``None`` runs the
                unconditional path.

        Returns:
            Tensor of shape ``(batch, c_out, H, W)``.
        """
        t = t.unsqueeze(-1).type(torch.float)
        t = self.pos_encoding(t, self.time_dim)

        if y is not None:
            t = t + self.label_emb(y)

        # Input dropout is applied on every training pass so that toggling
        # `use_checkpoint` only trades memory for compute and never changes
        # the maths; in eval mode nn.Dropout is the identity.
        x = self.dropout(x)

        # Encoder
        x1 = self._run(self.inc, x)
        x2 = self._run(self.down1, x1, t)
        x2 = self._run(self.sa1, x2)
        x3 = self._run(self.down2, x2, t)
        x3 = self._run(self.sa2, x3)
        x4 = self._run(self.down3, x3, t)
        x4 = self._run(self.sa3, x4)

        # Bottleneck
        x4 = self._run(self.bot1, x4)
        x4 = self._run(self.bot2, x4)
        x4 = self._run(self.bot3, x4)

        # Decoder, consuming the encoder outputs as skip connections
        x = self._run(self.up1, x4, x3, t)
        x = self._run(self.sa4, x)
        x = self._run(self.up2, x, x2, t)
        x = self._run(self.sa5, x)
        x = self._run(self.up3, x, x1, t)
        x = self._run(self.sa6, x)
        return self._run(self.outc, x)

# Configuration Class
class Config:
    """Hyper-parameters and dataset locations for the facial diffusion run.

    NOTE(review): ``slice_size``, ``do_validation``, ``fp16_scaler``,
    ``log_every_epoch`` and ``warmup_steps`` are not read anywhere in the
    visible notebook code — confirm whether they are still needed.
    """

    def __init__(self):
        # Run identity and reproducibility
        self.run_name = "facial_diffusion"
        self.seed = 42
        if torch.cuda.is_available():
            self.device = "cuda"
        else:
            self.device = "cpu"

        # Diffusion / model
        self.noise_steps = 500
        self.img_size = 64
        self.num_classes = 3  # Europeans, Indians, Orientals

        # Optimisation
        self.epochs = 200
        self.batch_size = 10
        self.lr = 2e-4
        self.warmup_steps = 100
        self.fp16 = True
        self.fp16_scaler = True

        # Data loading, validation and logging
        self.num_workers = 2
        self.slice_size = 1
        self.do_validation = True
        self.log_every_epoch = 10

        # Class name -> image folder; the insertion order defines the label index.
        self.dataset_paths = {
            'Europeans': '/content/dataset/dataset_gc/dataset/Europeans',
            'Indians': '/content/dataset/dataset_gc/dataset/Indians',
            'Orientals': '/content/dataset/dataset_gc/dataset/Orientals'
        }

class ImageDataset(torch.utils.data.Dataset):
    """Unlabelled dataset over the image files found directly inside one folder.

    Args:
        path: Folder to scan (not recursive). Entries whose names end in
            ``.png``, ``.jpg`` or ``.jpeg`` (case-insensitive) are kept.
        transform: Optional callable applied to every loaded RGB image.
    """

    def __init__(self, path, transform=None):
        self.path = path
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping identical from run to run.
        self.image_files = sorted(
            f for f in os.listdir(path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        )

    def __len__(self):
        """Number of image files found in the folder."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return it, transformed if a transform was given."""
        image_path = os.path.join(self.path, self.image_files[idx])
        # convert() returns a fully loaded copy, so the file handle can be
        # closed right away instead of being left to the garbage collector.
        with Image.open(image_path) as handle:
            image = handle.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image


# Main Diffusion Class
class Diffusion:
    def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=64, num_classes=3, c_in=3, c_out=3, device="cuda"):
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.beta = self.prepare_noise_schedule().to(device)
        self.alpha = 1. - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)
        self.img_size = img_size
        self.model = UNet_conditional(c_in, c_out, num_classes=num_classes, device=device).to(device)
        self.ema_model = copy.deepcopy(self.model).eval().requires_grad_(False)
        self.device = device
        self.c_in = c_in
        self.num_classes = num_classes

    def prepare_noise_schedule(self):
        return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)

    def sample_timesteps(self, n):
        return torch.randint(low=1, high=self.noise_steps, size=(n,))

    def noise_images(self, x, t):
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        Ɛ = torch.randn_like(x, requires_grad=True)  # Add requires_grad=True
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * Ɛ, Ɛ

    def prepare(self, args):
        train_transforms = T.Compose([
        T.Resize(args.img_size),
        T.RandomHorizontalFlip(p=0.3),  # Add subtle flipping
        T.ColorJitter(brightness=0.1, contrast=0.1),  # Slight color variation
        T.RandomRotation(degrees=5),  # Minor rotation
        T.CenterCrop(args.img_size),
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

        # Create datasets for each ethnic group
        datasets = []
        for idx, (group, path) in enumerate(args.dataset_paths.items()):
            dataset = ImageDataset(path, transform=train_transforms)
            # Create labels tensor for this dataset
            labels = torch.full((len(dataset),), idx, dtype=torch.long)
            # Wrap dataset and labels together
            dataset = torch.utils.data.TensorDataset(
                torch.stack([d for d in dataset]),
                labels
            )
            datasets.append(dataset)

        # Combine all datasets
        combined_dataset = ConcatDataset(datasets)

        # Create dataloader
        self.train_dataloader = DataLoader(
            combined_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers
        )

       # Initialize optimizer and other components
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=args.lr)
        self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
            self.optimizer,
            max_lr=args.lr,
            steps_per_epoch=len(self.train_dataloader),
            epochs=args.epochs
        )
        self.mse = nn.MSELoss()
        self.ema = EMA(0.995)

        self.fp16 = args.fp16

        # Initialize scaler for fp16
        if self.fp16:
            self.scaler = GradScaler(enabled=True)
        else:
            self.scaler = GradScaler(enabled=False)

    @torch.inference_mode()
    def sample(self, use_ema, labels, cfg_scale=3):
        model = self.ema_model if use_ema else self.model
        n = len(labels)
        model.eval()

        with torch.inference_mode():
            x = torch.randn((n, self.c_in, self.img_size, self.img_size)).to(self.device)
            temp = 0.8  # Slightly lower temperature for better details
            x = x * temp

            for i in progress_bar(reversed(range(1, self.noise_steps)), total=self.noise_steps-1, leave=False):
                t = (torch.ones(n) * i).long().to(self.device)
                predicted_noise = model(x, t, labels)

                if cfg_scale > 0:
                    uncond_predicted_noise = model(x, t, None)
                    predicted_noise = torch.lerp(uncond_predicted_noise, predicted_noise, cfg_scale)

                alpha = self.alpha[t][:, None, None, None]
                alpha_hat = self.alpha_hat[t][:, None, None, None]
                beta = self.beta[t][:, None, None, None]

                if i > 1:
                    noise = torch.randn_like(x)
                else:
                    noise = torch.zeros_like(x)

                x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise

        x = (x.clamp(-1, 1) + 1) / 2
        x = (x * 255).type(torch.uint8)
        return x


    def train_step(self, loss):
        if self.fp16:
            self.scaler.scale(loss).backward()
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()

        self.ema.step_ema(self.ema_model, self.model)
        self.scheduler.step()

    def one_epoch(self, train=True):
        avg_loss = 0.
        if train:
            self.model.train()
        else:
            self.model.eval()

        # Remove this line since we already have self.scaler
        # scaler = torch.cuda.amp.GradScaler(enabled=self.fp16)

        pbar = progress_bar(self.train_dataloader, leave=False)
        for i, (images, labels) in enumerate(pbar):
            if i % 10 == 0:
                torch.cuda.empty_cache()

            images = images.to(self.device)
            labels = labels.to(self.device)

            # Modified training step
            self.optimizer.zero_grad()

            with torch.cuda.amp.autocast(enabled=self.fp16):
                t = self.sample_timesteps(images.shape[0]).to(self.device)
                x_t, noise = self.noise_images(images, t)

                if np.random.random() < 0.1:
                    labels = None

                predicted_noise = self.model(x_t, t, labels)
                loss = self.mse(noise, predicted_noise)

            if train:
                if self.fp16:
                    # Use self.scaler instead of scaler
                    self.scaler.scale(loss).backward()
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                    self.optimizer.step()

                self.ema.step_ema(self.ema_model, self.model)
                self.scheduler.step()

                wandb.log({
                    "train_mse": loss.item(),
                    "learning_rate": self.scheduler.get_last_lr()[0]
                })

            avg_loss += loss.item()
            pbar.comment = f"MSE={loss.item():2.3f}"

            # Clear memory
            del predicted_noise, x_t, noise
            torch.cuda.empty_cache()

        return avg_loss / len(self.train_dataloader)

    def validate(self, epoch):
        if epoch % 10 == 0:
            self.model.eval()
            with torch.no_grad():
                # Generate sample images
                labels = torch.tensor([0, 1, 2] * 2).to(self.device)
                samples = self.sample(use_ema=True, labels=labels)

                # Log images
                wandb.log({
                    f"validation_samples_epoch_{epoch}": [
                        wandb.Image(img.permute(1,2,0).cpu().numpy())
                        for img in samples
                    ]
                })

    def fit(self, args):
        try:
            for epoch in progress_bar(range(args.epochs), total=args.epochs, leave=True):
                logging.info(f"Starting epoch {epoch}:")
                train_loss = self.one_epoch(train=True)

                wandb.log({
                    "epoch": epoch,
                    "train_loss": train_loss,
                })

                self.validate(epoch)

        except RuntimeError as e:
            if "out of memory" in str(e):
                print('| WARNING: out of memory')
                if hasattr(torch.cuda, 'empty_cache'):
                    torch.cuda.empty_cache()
            else:
                raise e
    def visualize_batch(self, images, labels):
        """Debug function to visualize training data"""
        plt.figure(figsize=(20, 4))
        for i in range(min(10, len(images))):
            plt.subplot(1, 10, i+1)
            img = images[i].cpu().permute(1, 2, 0)
            img = (img * 0.5 + 0.5).numpy()  # Denormalize
            plt.imshow(img)
            plt.title(f'Label: {labels[i].item()}')
            plt.axis('off')
        plt.show()

def main():
    """Train the diffusion model, then sample, save and log a set of label batches.

    For every named label batch a PNG called ``generated_faces_<name>.png`` is
    written to the working directory and the images are logged to wandb.
    """
    config = Config()
    set_seed(config.seed)

    # Pass the device chosen by Config so a CPU-only runtime does not fall
    # back to Diffusion's "cuda" default.
    diffusion = Diffusion(
        noise_steps=config.noise_steps,
        img_size=config.img_size,
        num_classes=config.num_classes,
        device=config.device,
    )

    with wandb.init(project="facial_diffusion", config=vars(config)):
        diffusion.prepare(config)
        diffusion.fit(config)

        # Label indices follow Config.dataset_paths:
        # 0 = Europeans, 1 = Indians, 2 = Orientals.
        # NOTE(review): the three "binary" batches and their names do not
        # match — every one of them contains all three class indices.
        # Confirm the intended mixes before relying on these names.
        combinations = [
            # Single group generations (for baseline)
            ("Europeans", torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])),
            ("Indians", torch.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])),
            ("Orientals", torch.tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2])),

            # "Binary" combinations (see NOTE above)
            ("Orientals_Indians", torch.tensor([0, 0, 1, 1, 2, 2, 0, 1, 2, 0])),
            ("Orientals_Europeans", torch.tensor([0, 0, 2, 2, 1, 1, 0, 2, 1, 0])),
            ("Indians_Europeans", torch.tensor([1, 1, 0, 0, 2, 2, 1, 0, 2, 1])),

            # Tri-group combination
            ("Mixed_All", torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2, 0])),
        ]

        # Sample and save images for each combination
        for name, labels in combinations:
            labels = labels.to(diffusion.device)
            sampled_images = diffusion.sample(use_ema=True, labels=labels)

            # One row with as many panels as there are samples; squeeze=False
            # keeps `axes` two-dimensional even for a single sample.
            fig, axes = plt.subplots(1, len(sampled_images), figsize=(20, 4), squeeze=False)
            # A figure-level title: an axes-level title set before the
            # subplots exist would end up on a stray extra axes.
            fig.suptitle(f"Generated Faces - {name}")
            for ax, img in zip(axes[0], sampled_images):
                ax.imshow(img.permute(1,2,0).cpu().numpy())
                ax.axis('off')
            fig.tight_layout()
            fig.savefig(f'generated_faces_{name}.png')
            plt.close(fig)

            # Log images to wandb
            wandb.log({
                f"generated_images_{name}": [
                    wandb.Image(img.permute(1,2,0).cpu().numpy())
                    for img in sampled_images
                ]
            })

if __name__ == "__main__":
    main()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


  self.scaler = GradScaler(enabled=True)


  with torch.cuda.amp.autocast(enabled=self.fp16):
  return fn(*args, **kwargs)


  with torch.cuda.amp.autocast(enabled=self.fp16):
  return fn(*args, **kwargs)


0,1
epoch,▁▁▁▁▁▂▂▂▂▂▂▂▂▂▂▂▃▃▃▃▃▃▃▃▃▅▅▅▅▅▆▆▆▆▇▇▇▇██
learning_rate,▁▁▁▂▂▆▆▇███████▇▇▇▇▇▆▆▆▄▄▃▃▃▃▃▂▂▂▂▂▁▁▁▁▁
train_loss,█▅▅▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_mse,▅█▂▃▁▃▃▂▃▂▁▂▂▃▂▄▂▁▁▂▁▂▁▁▂▂▃▁▃▂▁▂▁▁▁▅▁▂▁▁

0,1
epoch,199.0
learning_rate,0.0
train_loss,0.13296
train_mse,0.14143
