In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [1]:
#Check imports
try:
  import torch
  import torch.nn as nn
  import pytorch_lightning as pl
  from torch.utils.data import DataLoader, random_split, Dataset
  import matplotlib.pyplot as plt
  import wandb
  from pytorch_lightning.loggers import WandbLogger
  from pytorch_lightning.callbacks import Callback
  from pytorch_msssim import ssim

except Exception as e:
  print(f"Exception = {e}")

In [3]:
import argparse
import datetime
import json
import math
import os
import sys
import time
import warnings
from functools import partial
from pathlib import Path
from typing import Dict, Iterable, List

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import yaml

import utils
from multimae.criterion import MaskedMSELoss
from multimae.input_adapters import PatchedInputAdapter
from multimae.output_adapters import SpatialOutputAdapter
from utils import NativeScalerWithGradNormCount as NativeScaler
from utils import create_model
from utils.datasets_chloe import build_multimae_pretraining_dataset
from utils.optim_factory import create_optimizer
from utils.task_balancing import (NoWeightingStrategy,
                                  UncertaintyWeightingStrategy)

In [None]:
parser = argparse.ArgumentParser('MultiMAE pre-training script', add_help=False)
parser.add_argument('--batch_size', default=32, type=int,
                    help='Batch size per GPU (default: %(default)s)')
parser.add_argument('--epochs', default=20, type=int,
                    help='Number of epochs (default: %(default)s)')
parser.add_argument('--save_ckpt_freq', default=1, type=int,
                    help='Checkpoint saving frequency in epochs (default: %(default)s)')
# Task parameters
parser.add_argument('--in_domains', default='s1-s2', type=str,
                    help='Input domain names, separated by hyphen (default: %(default)s)')
parser.add_argument('--out_domains', default='s1-s2', type=str,
                    help='Output domain names, separated by hyphen (default: %(default)s)')

# Model parameters
parser.add_argument('--model', default='pretrain_multimae_base', type=str, metavar='MODEL',
                    help='Name of model to train (default: %(default)s)')
parser.add_argument('--num_encoded_tokens', default=784, type=int,
                    help='Number of tokens to randomly choose for encoder (default: %(default)s)')
parser.add_argument('--num_global_tokens', default=1, type=int,
                    help='Number of global tokens to add to encoder (default: %(default)s)')
parser.add_argument('--patch_size', default=16, type=int,
                    help='Base patch size for image-like modalities (default: %(default)s)')
parser.add_argument('--input_size', default=224, type=int,
                    help='Images input size for backbone (default: %(default)s)')
parser.add_argument('--alphas', type=float, default=0.3, 
                    help='Dirichlet alphas concentration parameter (default: %(default)s)')
parser.add_argument('--sample_tasks_uniformly', default=True, action='store_true',
                    help='Set to True/False to enable/disable uniform sampling over tasks to sample masks for.')
parser.add_argument('--decoder_use_task_queries', default=True, action='store_true',
                    help='Set to True/False to enable/disable adding of task-specific tokens to decoder query tokens')
parser.add_argument('--decoder_use_xattn', default=True, action='store_true',
                    help='Set to True/False to enable/disable decoder cross attention.')
parser.add_argument('--decoder_dim', default=256, type=int,
                    help='Token dimension inside the decoder layers (default: %(default)s)')
parser.add_argument('--decoder_depth', default=2, type=int,
                    help='Number of self-attention layers after the initial cross attention (default: %(default)s)')
parser.add_argument('--decoder_num_heads', default=8, type=int,
                    help='Number of attention heads in decoder (default: %(default)s)')
parser.add_argument('--drop_path', type=float, default=0.0, metavar='PCT',
                    help='Drop path rate (default: %(default)s)')
parser.add_argument('--loss_on_unmasked', default=False, action='store_true',
                    help='Set to True/False to enable/disable computing the loss on non-masked tokens')
parser.add_argument('--no_loss_on_unmasked', action='store_false', dest='loss_on_unmasked')
parser.set_defaults(loss_on_unmasked=False)
# Optimizer parameters
parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                    help='Optimizer (default: %(default)s)')
parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',
                    help='Optimizer epsilon (default: %(default)s)')
parser.add_argument('--opt_betas', default=[0.9, 0.95], type=float, nargs='+', metavar='BETA',
                    help='Optimizer betas (default: %(default)s)')
parser.add_argument('--clip_grad', type=float, default=1, metavar='CLIPNORM',
                    help='Clip gradient norm (default: %(default)s)')
parser.add_argument('--skip_grad', type=float, default=None, metavar='SKIPNORM',
                    help='Skip update if gradient norm larger than threshold (default: %(default)s)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum (default: %(default)s)')
parser.add_argument('--weight_decay', type=float, default=0.05,
                    help='Weight decay (default: %(default)s)')
parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the
    weight decay. We use a cosine schedule for WD.  (Set the same value as args.weight_decay to keep weight decay unchanged)""")
parser.add_argument('--decoder_decay', type=float, default=None, help='decoder weight decay')
parser.add_argument('--blr', type=float, default=1e-4, metavar='LR',
                    help='Base learning rate: absolute_lr = base_lr * total_batch_size / 256 (default: %(default)s)')
parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR',
                    help='Warmup learning rate (default: %(default)s)')
parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
                    help='Lower lr bound for cyclic schedulers that hit 0 (default: %(default)s)')
parser.add_argument('--task_balancer', type=str, default='none',
                    help='Task balancing scheme. One out of [uncertainty, none] (default: %(default)s)')
parser.add_argument('--balancer_lr_scale', type=float, default=1.0,
                    help='Task loss balancer LR scale (if used) (default: %(default)s)')
parser.add_argument('--warmup_epochs', type=int, default=0, metavar='N',
                    help='Epochs to warmup LR, if scheduler supports (default: %(default)s)')
parser.add_argument('--warmup_steps', type=int, default=-0, metavar='N',
                    help='Epochs to warmup LR, if scheduler supports (default: %(default)s)')
parser.add_argument('--fp32_output_adapters', type=str, default='',
                    help='Tasks output adapters to compute in fp32 mode, separated by hyphen.')
# Augmentation parameters
parser.add_argument('--hflip', type=float, default=0.5,
                    help='Probability of horizontal flip (default: %(default)s)')
parser.add_argument('--train_interpolation', type=str, default='bicubic',
                    help='Training interpolation (random, bilinear, bicubic) (default: %(default)s)')
# Dataset parameters
parser.add_argument('--data_path', type=str, default=None,
                    help='(optional) base dir if your txt paths are relative.')
parser.add_argument(
    '--s1_txt',
    type=str,
    default="/work/mech-ai-scratch/bgekim/project/imputation/MultiMAE_NEW/MultiMAE/valid_list/nova/30m/pair_S1.txt",
    help="Path to modis txt file"
)

parser.add_argument(
    '--s2_txt',
    type=str,
    default="/work/mech-ai-scratch/bgekim/project/imputation/MultiMAE_NEW/MultiMAE/valid_list/nova/30m/pair_S2.txt",
    help="Path to s2 txt file"
)

parser.add_argument(
    '--all_domains',
    type=str,
    default='s1-s2',
    help='All domain names, separated by hyphen'
)


parser.add_argument('--imagenet_default_mean_and_std', default=False, action='store_true')
# Misc.
parser.add_argument('--output_dir', default='/work/mech-ai-scratch/bgekim/project/imputation/MultiMAE_NEW/MultiMAE/result/s1-s2-new/',
                    help='Path where to save, empty for no saving')
parser.add_argument('--device', default='cuda',
                    help='Device to use for training / testing')
parser.add_argument('--seed', default=0, type=int, help='Random seed ')
parser.add_argument('--resume', default='', help='resume from checkpoint')
parser.add_argument('--auto_resume', action='store_true')
parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume')
parser.set_defaults(auto_resume=True)
parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='start epoch')
parser.add_argument('--num_workers', default=4, type=int)
parser.add_argument('--pin_mem', action='store_true',
                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem',
                    help='')
parser.set_defaults(pin_mem=False)
parser.add_argument('--find_unused_params', action='store_true')
parser.add_argument('--no_find_unused_params', action='store_false', dest='find_unused_params')
parser.set_defaults(find_unused_params=True)
# Wandb logging
parser.add_argument('--log_wandb', default=False, action='store_true',
                    help='Log training and validation metrics to wandb')
parser.add_argument('--no_log_wandb', action='store_false', dest='log_wandb')
parser.set_defaults(log_wandb=False)
parser.add_argument('--wandb_project', default='MultiMAE-RGB', type=str,
                    help='Project name on wandb')
parser.add_argument('--wandb_entity', default='goeulkim', type=str,
                    help='User or team name on wandb')
parser.add_argument('--wandb_run_name', default='multimae-modis-s2', type=str,
                    help='Run name on wandb')
parser.add_argument('--show_user_warnings', default=False, action='store_true')
# Distributed training parameters
parser.add_argument('--world_size', default=1, type=int,
                    help='number of distributed processes')
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--dist_on_itp', action='store_true')
parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')


args = parser.parse_args(args=[])
args.all_domains = args.all_domains.split('-')

In [5]:
# Get dataset
dataset = build_multimae_pretraining_dataset(args)

In [6]:
print(len(dataset))

118059


In [7]:
# 2. 비율 정의 (예: 70% train, 15% val, 15% test)
train_ratio, val_ratio, test_ratio = 0.7, 0.15, 0.15
n_total = len(dataset)
n_train = int(n_total * train_ratio)
n_val = int(n_total * val_ratio)
n_test = n_total - n_train - n_val  # 나머지

# 3. split
train_dataset, val_dataset, test_dataset = random_split(
    dataset, [n_train, n_val, n_test],
    generator=torch.Generator().manual_seed(args.seed)  # reproducibility
)


# # 2. 비율 정의 (예: 80% train, 20% val, no test)
# train_ratio, val_ratio = 0.8, 0.2
# n_total = len(dataset)
# n_train = int(n_total * train_ratio)
# n_val = n_total - n_train 

# # 3. split
# train_dataset, val_dataset = random_split(
#     dataset, [n_train, n_val],
#     generator=torch.Generator().manual_seed(args.seed)  # reproducibility
# )

print(len(train_dataset))
print(len(val_dataset))
print(len(test_dataset))


82641
17708
17710


In [8]:
# 4. DataLoader 생성
train_loader = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.num_workers,
    pin_memory=args.pin_mem
)

val_loader = DataLoader(
    val_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.num_workers,
    pin_memory=args.pin_mem
)

test_loader = DataLoader(
    test_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.num_workers,
    pin_memory=args.pin_mem
)

In [9]:
batch = next(iter(val_loader))
print(batch.keys())

dict_keys(['s1', 's2'])


In [10]:
for key, tensor in batch.items():
  print(f"{key}: {tensor.shape}")

s1: torch.Size([32, 2, 224, 224])
s2: torch.Size([32, 12, 224, 224])


## From MultiMAE

## Pretraining

In [11]:
import torch
import torch.backends.cudnn as cudnn
from functools import partial
from multimae.criterion import MaskedMSELoss
from multimae.input_adapters import PatchedInputAdapter
from multimae.output_adapters import SpatialOutputAdapter
from utils import create_model, NativeScalerWithGradNormCount as NativeScaler
from utils.optim_factory import create_optimizer
import utils
import wandb

import torch
import torch.nn.functional as F

In [None]:
# # add contrastive loss
# def contrastive_loss(z1, z2, temperature=0.1):
#     """
#     Cross-modal contrastive loss (NT-Xent style).
    
#     Args:
#         z1, z2: [B, D] latent vectors from two modalities (e.g., S1 and S2)
#         temperature: scaling factor for logits
    
#     Returns:
#         Scalar contrastive loss
#     """
#     # Normalize embeddings
#     z1 = F.normalize(z1, dim=1)
#     z2 = F.normalize(z2, dim=1)

#     # Compute similarity matrix
#     logits = torch.matmul(z1, z2.T) / temperature  # [B, B]
#     labels = torch.arange(z1.size(0), device=z1.device)

#     # Cross entropy loss for both directions
#     loss_12 = F.cross_entropy(logits, labels)
#     loss_21 = F.cross_entropy(logits.T, labels)

#     return (loss_12 + loss_21) / 2

In [12]:
# ------------------------------
# Input Domain Configuration (수정: 채널 수 맞게!)
# ------------------------------
DOMAIN_CONF = {
    's1': {
        'channels': 2,
        'stride_level': 1,
        'input_adapter': partial(PatchedInputAdapter, num_channels=2),
        'output_adapter': partial(SpatialOutputAdapter, num_channels=2),
        'loss': MaskedMSELoss, 
    },
    's2': {
        'channels': 12,
        'stride_level': 1,
        'input_adapter': partial(PatchedInputAdapter, num_channels=12),
        'output_adapter': partial(SpatialOutputAdapter, num_channels=12),
        'loss': MaskedMSELoss,  
    },
}

In [13]:
# ------------------------------
# Model Builder
# ------------------------------
def get_model(in_domains, out_domains, patch_size=4, decoder_dim=256):
    input_adapters = {
        d: DOMAIN_CONF[d]['input_adapter'](stride_level=1, patch_size_full=patch_size)
        for d in in_domains
    }
    output_adapters = {
        d: DOMAIN_CONF[d]['output_adapter'](
            stride_level=1,
            patch_size_full=patch_size,
            dim_tokens=decoder_dim,
            depth=2,
            num_heads=8,
            use_task_queries=True,
            task=d,
            context_tasks=in_domains,
            use_xattn=True
        )
        for d in out_domains
    }
    return create_model(
        "pretrain_multimae_base",
        input_adapters=input_adapters,
        output_adapters=output_adapters,
        num_global_tokens=1,
        drop_path_rate=0.0
    )

In [14]:
# ------------------------------
# Training Loop
# ------------------------------
def train_one_epoch(model, train_loader, tasks_loss_fn, optimizer, device, epoch, loss_scaler, in_domains, out_domains, split="train"):
    if split == "train":
        model.train()
    else:
        model.eval()

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = f"Epoch: [{epoch}]"

    for step, batch in enumerate(metric_logger.log_every(train_loader, 10, header)):
        tasks_dict = {t: ten.to(device, non_blocking=True) for t, ten in batch.items()}
        input_dict = {t: tasks_dict[t] for t in in_domains if t in tasks_dict}

        with torch.cuda.amp.autocast():
            preds, masks = model(input_dict, num_encoded_tokens=args.num_encoded_tokens) # original 
            # preds, masks, latents = model(input_dict, num_encoded_tokens=args.num_encoded_tokens, return_encoded=True) # to add contrastive loss

            task_losses = {}
            for task in out_domains:
                target = tasks_dict[task]
                task_losses[task] = tasks_loss_fn[task](preds[task].float(), target)

            loss_recon = sum(task_losses.values())

            # # Contrastive loss (예: modis vs s2)
            # z_s1 = latents['s1'].mean(dim=1)   # [B, D]
            # z_s2    = latents['s2'].mean(dim=1)      # [B, D]
            # loss_contrast = contrastive_loss(z_s1, z_s2, temperature=0.1)

            # loss = loss_recon + 0.1*loss_contrast
            loss = loss_recon
        
        if split == "train":
            optimizer.zero_grad()
            grad_norm = loss_scaler(loss, optimizer, parameters=model.parameters(), clip_grad=args.clip_grad)
            torch.cuda.synchronize()
        else:
            grad_norm = 0.0  # validation은 grad 없음


        metric_logger.update(loss=loss.item(), grad_norm=grad_norm)
        for task, l in task_losses.items():
            metric_logger.update(**{f'{task}_loss': l.item()})
        
        wandb.log({
            "epoch": epoch,
            "step": step,
            f"{split}_loss_total": loss.item(),
            f"{split}_grad_norm": grad_norm,
            **{f"{split}_{task}_loss": l.item() for task, l in task_losses.items()}
        })

    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return metric_logger


In [31]:
# ------------------------------
# Test Loop
# ------------------------------
def test_one_epoch(model, test_loader, tasks_loss_fn, device, epoch, in_domains, out_domains):


    # ----------------- W&B Init -----------------
    wandb.init(
        project="multimae-newdataset",
        entity="goeulkim"
        # id="gq9vrsal",   # 네가 알려줄 run id
        # resume="allow",  # 기존 run에 추가
    )

    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = f"Test: [Epoch {epoch}]"

    with torch.no_grad():
        for step, batch in enumerate(metric_logger.log_every(test_loader, 10, header)):
            tasks_dict = {t: ten.to(device, non_blocking=True) for t, ten in batch.items()}
            input_dict = {t: tasks_dict[t] for t in in_domains if t in tasks_dict}

            with torch.cuda.amp.autocast():
                preds, masks = model(input_dict, num_encoded_tokens=args.num_encoded_tokens)

                task_losses = {}
                for task in out_domains:
                    target = tasks_dict[task]
                    task_losses[task] = tasks_loss_fn[task](preds[task].float(), target)

                loss = sum(task_losses.values())

            # grad 없음
            grad_norm = 0.0  

            # 기록
            metric_logger.update(loss=loss.item(), grad_norm=grad_norm)
            for task, l in task_losses.items():
                metric_logger.update(**{f'{task}_loss': l.item()})

            wandb.log({
                "epoch": epoch,
                "step": step,
                "test_loss_total": loss.item(),
                "test_grad_norm": grad_norm,
                **{f"test_{task}_loss": l.item() for task, l in task_losses.items()}
            })

    metric_logger.synchronize_between_processes()
    print("Test stats:", metric_logger)
    return metric_logger


In [None]:
# ------------------------------
# Main
# ------------------------------
def main(train_loader, val_loader):
    cudnn.benchmark = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ----------------- Config -----------------
    batch_size = args.batch_size
    epochs = args.epochs
    lr = args.blr
    patch_size = args.patch_size
    decoder_dim = args.decoder_dim

    in_domains = ['s1', 's2']
    out_domains = ['s1', 's2']

    wandb.init(
        project="multimae-newdataset",
        name="chloe_code",
        entity="goeulkim",
        config={
            "epochs":epochs,
            "lr":lr,
            "batch_size":batch_size,
            "patch_size":patch_size,
            "decoder_dim":decoder_dim
        }
    )

    # ----------------- Model -----------------
    model = get_model(in_domains, out_domains, patch_size, decoder_dim).to(device)
    
    # optimizer = create_optimizer(lr, model)
    # optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=args.weight_decay)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        eps=args.opt_eps,
        betas=tuple(args.opt_betas),
        weight_decay=args.weight_decay
        )

    loss_scaler = NativeScaler()

    tasks_loss_fn = {
        d: DOMAIN_CONF[d]['loss'](patch_size=patch_size, stride=1)
        for d in out_domains
    }

    torch.autograd.set_detect_anomaly(True)
    best_val_loss = float("inf")

    # ----------------- Training -----------------
    for epoch in range(epochs):
        train_one_epoch(model, train_loader, tasks_loss_fn, optimizer, device, epoch, loss_scaler, in_domains, out_domains, split="train")
        # train_one_epoch(model, val_loader, tasks_loss_fn, optimizer, device, epoch, loss_scaler, in_domains, out_domains, split="valid")

        # Validate
        val_stats = train_one_epoch(model, val_loader, tasks_loss_fn, optimizer, device, epoch, loss_scaler, in_domains, out_domains, split="valid")

        # Validation loss 가져오기
        val_loss = val_stats.meters['loss'].global_avg



    # ----------------- Save -----------------
    torch.save(model.state_dict(), "pretrain_multimae.pth")
    # wandb.finish()
    print("✅ Pretrained model saved at pretrain_multimae.pth")


if __name__ == "__main__":
    # train_loader already here
    # from my_dataset import train_loader
    # main(train_loader)
    pass


In [16]:
main(train_loader, val_loader)

[34m[1mwandb[0m: Currently logged in as: [33mgoeul8604[0m ([33mgoeulkim[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  self._scaler = torch.cuda.amp.GradScaler(enabled=enabled)
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast(enabled=False):


loss contains NaN: False
loss contains NaN: False
Epoch: [0]  [   0/2583]  eta: 10:58:03  loss: 3.7516 (3.7516)  grad_norm: 25.7891 (25.7891)  s1_loss: 2.3486 (2.3486)  s2_loss: 1.4030 (1.4030)  time: 15.2860  data: 4.0632  max mem: 10350
loss contains NaN: False
loss contains NaN: False
loss contains NaN: False
loss contains NaN: False
loss contains NaN: False
loss contains NaN: False
loss contains NaN: False
loss contains NaN: False
loss contains NaN: False
loss contains NaN: False
loss contains NaN: False
loss contains NaN: False
loss contains NaN: False
loss contains NaN: False
loss contains NaN: False
loss contains NaN: False
loss contains NaN: False
loss contains NaN: False
loss contains NaN: False
loss contains NaN: False
Epoch: [0]  [  10/2583]  eta: 1:19:23  loss: 1.4555 (2.0085)  grad_norm: 6.1470 (9.2881)  s1_loss: 0.5649 (1.0763)  s2_loss: 0.7832 (0.9322)  time: 1.8512  data: 0.4349  max mem: 11442
loss contains NaN: False
loss contains NaN: False
loss contains NaN: False
l

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇███
step,▇▅█▂▂▆▇▂▂▁▁▃▆▇▂▄▄█▁▁▃▄▆▃▃▂▃▄▅▆▁▂▅▆▂▃▆▂▂▁
train_grad_norm,██▄▄▃▃▃▂▃▄▃▂▂▂▂▃▂▁▁▂▂▂▁▁▂▂▂▂▁▁▁▁▁▂▁▂▁▁▁▁
train_loss_total,█▃▂▂▂▂▂▂▁▂▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_s1_loss,█▆▅▃▃▃▅▃▃▂▂▂▂▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_s2_loss,██▇▄▃▄▃▄▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁
valid_grad_norm,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
valid_loss_total,▇▇██▅▃▄▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▂▁▂▁▁▁▁▁▁▁▁▁
valid_s1_loss,▇█▆▆▄▄▄▃▄▃▂▃▂▂▂▂▂▂▂▁▂▁▁▁▂▁▂▂▁▁▁▁▁▂▁▁▁▁▁▁
valid_s2_loss,█▅▅▅▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▂▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,19.0
step,553.0
train_grad_norm,0.03564
train_loss_total,0.01566
train_s1_loss,0.00367
train_s2_loss,0.01199
valid_grad_norm,0.0
valid_loss_total,0.0157
valid_s1_loss,0.00429
valid_s2_loss,0.01142


✅ Pretrained model saved at pretrain_multimae.pth


In [33]:
wandb.finish()

0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
step,▁▁▁▁▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇███
test_grad_norm,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
test_loss_total,▅█▃█▆▄▅▂▄▃▅▇▄▄▅█▃▄▂▅▃▂▄▄▃▅▄▅▃▃▃▆▃▆▁▄▃▆▄▇
test_s1_loss,▄▂▂▄▄▄▅▂▃▄█▆▆▃▂▃▆▅▅▄▂▁▄▅▂▄▁▂▃▇▂▁▃▃█▃▂▄▅▂
test_s2_loss,▄▆▅▅▃▃▁▄▁█▆▆▃▃▄▄▄▃▄▆▃▄▇▄▃▂▂▃▂▃▃▅▁▄▃▆▁▃▅▃

0,1
epoch,-1.0
step,553.0
test_grad_norm,0.0
test_loss_total,0.01619
test_s1_loss,0.00607
test_s2_loss,0.01012


In [32]:
# ----------------- Config -----------------
batch_size = args.batch_size
epochs = args.epochs
lr = args.blr
patch_size = args.patch_size
decoder_dim = args.decoder_dim
in_domains = ['s1', 's2']
out_domains = ['s1', 's2']


# ----------------- Model ----------------- 
model = get_model(in_domains, out_domains, patch_size, decoder_dim).to(device)

# # ----------------- Load checkpoint -----------------
# ckpt_path = "/work/mech-ai-scratch/bgekim/project/imputation/MultiMAE_NEW/MultiMAE/model/best_model.pth"
# checkpoint = torch.load(ckpt_path, map_location=device)

# # case 1: if you saved with {"model": state_dict, "optimizer":..., "epoch":...}
# if "model" in checkpoint:
#     model.load_state_dict(checkpoint["model"], strict=False)
# else:
#     # case 2: if you saved only state_dict
#     model.load_state_dict(checkpoint, strict=False)


# ----------------- Load checkpoint -----------------
ckpt_path = "/work/mech-ai-scratch/bgekim/project/imputation/MultiMAE_NEW/MultiMAE/model/best_model.pth"
checkpoint = torch.load(ckpt_path, map_location=device)

# 모델 가중치 로드
model.load_state_dict(checkpoint["model_state_dict"], strict=True)

# 몇 epoch에서 저장했는지 확인
start_epoch = checkpoint["epoch"]
print(f"✅ Loaded model from epoch {start_epoch}")



tasks_loss_fn = {
    d: DOMAIN_CONF[d]['loss'](patch_size=patch_size, stride=1)
    for d in out_domains
}


# ----------------- Run Test -----------------
test_stats = test_one_epoch(
    model, 
    test_loader, 
    tasks_loss_fn, 
    device, 
    epoch=-1,          # testing only
    in_domains=in_domains, 
    out_domains=out_domains
)

print("Test finished. Averaged stats:", test_stats)

✅ Loaded model from epoch 20


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast(enabled=False):


loss contains NaN: False
loss contains NaN: False
Test: [Epoch -1]  [  0/554]  eta: 0:08:21  loss: 0.0167 (0.0167)  grad_norm: 0.0000 (0.0000)  s1_loss: 0.0035 (0.0035)  s2_loss: 0.0133 (0.0133)  time: 0.9058  data: 0.8285  max mem: 21065
loss contains NaN: False
loss contains NaN: False
loss contains NaN: False
loss contains NaN: False
loss contains NaN: False
loss contains NaN: False
loss contains NaN: False
loss contains NaN: False
loss contains NaN: False
loss contains NaN: False
loss contains NaN: False
loss contains NaN: False
loss contains NaN: False
loss contains NaN: False
loss contains NaN: False
loss contains NaN: False
loss contains NaN: False
loss contains NaN: False
loss contains NaN: False
loss contains NaN: False
Test: [Epoch -1]  [ 10/554]  eta: 0:02:02  loss: 0.0151 (0.0147)  grad_norm: 0.0000 (0.0000)  s1_loss: 0.0039 (0.0039)  s2_loss: 0.0105 (0.0107)  time: 0.2259  data: 0.1528  max mem: 21065
loss contains NaN: False
loss contains NaN: False
loss contains NaN: Fal

## dosnstream for cdl

In [None]:
import torch
import torch.backends.cudnn as cudnn
from functools import partial
from multimae.criterion import MaskedMSELoss
from multimae.input_adapters import PatchedInputAdapter
from multimae.output_adapters import SpatialOutputAdapter
from utils import create_model, NativeScalerWithGradNormCount as NativeScaler
from utils.optim_factory import create_optimizer
import utils
import wandb

# ------------------------------s
# Domain Configuration (harvest 포함)
# ------------------------------
DOMAIN_CONF = {
    'imagery': {
        'channels': 20, 
        'input_adapter': partial(PatchedInputAdapter, num_channels=20),
        'output_adapter': partial(SpatialOutputAdapter, num_channels=20),
        'loss': MaskedMSELoss,
    },
    'application': {
        'channels': 10,  
        'input_adapter': partial(PatchedInputAdapter, num_channels=10),
        'output_adapter': partial(SpatialOutputAdapter, num_channels=10),
        'loss': MaskedMSELoss,
    },
    'seeding': {
        'channels': 10, 
        'input_adapter': partial(PatchedInputAdapter, num_channels=10),
        'output_adapter': partial(SpatialOutputAdapter, num_channels=10),
        'loss': MaskedMSELoss,
    },
    'soils': {
        'channels': 2,  
        'input_adapter': partial(PatchedInputAdapter, num_channels=2),
        'output_adapter': partial(SpatialOutputAdapter, num_channels=2),
        'loss': MaskedMSELoss,
    },
    'harvest': {
        'channels': 1,   # target (yield)
        'input_adapter': partial(PatchedInputAdapter, num_channels=1),
        'output_adapter': partial(SpatialOutputAdapter, num_channels=1),
        'loss': MaskedMSELoss,
    },
}

# ------------------------------
# Model Builder
# ------------------------------
def get_model(in_domains, out_domains, patch_size=4, decoder_dim=256):
    input_adapters = {
        d: DOMAIN_CONF[d]['input_adapter'](stride_level=1, patch_size_full=patch_size)
        for d in in_domains
    }
    output_adapters = {
        d: DOMAIN_CONF[d]['output_adapter'](
            stride_level=1,
            patch_size_full=patch_size,
            dim_tokens=decoder_dim,
            depth=2,
            num_heads=8,
            use_task_queries=True,
            task=d,
            context_tasks=in_domains,
            use_xattn=True
        )
        for d in out_domains
    }
    return create_model(
        "pretrain_multimae_base",
        input_adapters=input_adapters,
        output_adapters=output_adapters,
        num_global_tokens=1,
        drop_path_rate=0.0
    )

# ------------------------------
# Training Loop
# ------------------------------
def train_one_epoch(model, loader, tasks_loss_fn, optimizer, device, epoch, loss_scaler, in_domains, out_domains, split="train"):
    model.train() if split == "train" else model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = f"{split.capitalize()} Epoch: [{epoch}]"

    for step, batch in enumerate(metric_logger.log_every(loader, 10, header)):
        tasks_dict = {t: ten.to(device, non_blocking=True) for t, ten in batch.items()}
        input_dict = {t: tasks_dict[t] for t in in_domains if t in tasks_dict}

        with torch.cuda.amp.autocast():
            preds, masks = model(input_dict, num_encoded_tokens=args.num_encoded_tokens)
            task_losses = {}
            for task in out_domains:  # harvest
                target = tasks_dict[task]
                task_losses[task] = tasks_loss_fn[task](preds[task].float(), target)

            loss = sum(task_losses.values())

        if split == "train":
            optimizer.zero_grad()
            grad_norm = loss_scaler(loss, optimizer, parameters=model.parameters(), clip_grad=args.clip_grad)
            torch.cuda.synchronize()
        else:
            grad_norm = 0.0

        metric_logger.update(loss=loss.item(), grad_norm=grad_norm)
        for task, l in task_losses.items():
            metric_logger.update(**{f'{split}_{task}_loss': l.item()})

        # ✅ W&B log
        wandb.log({
            "epoch": epoch,
            "step": step,
            f"{split}_loss_total": loss.item(),
            f"{split}_grad_norm": grad_norm,
            **{f"{split}_{task}_loss": l.item() for task, l in task_losses.items()}
        })

    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    wandb.log({
      "epoch": epoch,
      f"{split}_loss_avg": metric_logger.meters["loss"].global_avg,
      **{f"{split}_{task}_loss_avg": metric_logger.meters[f"{split}_{task}_loss"].global_avg for task in out_domains}

    })

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

# ------------------------------
# Main
# ------------------------------
def main(train_loader, val_loader, test_loader):
    cudnn.benchmark = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ----------------- Config -----------------
    batch_size = 32
    epochs = 30
    lr = 5e-5
    patch_size = 4
    decoder_dim = 256

    in_domains = ['imagery', 'application', 'seeding', 'soils']
    out_domains = ['harvest']   # yield prediction

    # ----------------- W&B Init -----------------
    wandb.init(
        project="multimae-JD",
        entity="goeulkim",
        name="finetune_yield",
        config={
            "epochs": epochs,
            "lr": lr,
            "batch_size": batch_size,
            "patch_size": patch_size,
            "decoder_dim": decoder_dim,
            "in_domains": in_domains,
            "out_domains": out_domains,
        }
    )


    # ----------------- Model -----------------
    model = get_model(in_domains, out_domains, patch_size, decoder_dim).to(device)

    # ✅ pretrained weight 로드
    state_dict = torch.load("pretrain_multimae.pth", map_location=device)
    model.load_state_dict(state_dict, strict=False)  # harvest head는 새로 초기화됨

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05)
    loss_scaler = NativeScaler()

    tasks_loss_fn = {
        d: DOMAIN_CONF[d]['loss'](patch_size=patch_size, stride=1)
        for d in out_domains
    }

    # ----------------- Training -----------------
    for epoch in range(epochs):
        train_one_epoch(model, train_loader, tasks_loss_fn, optimizer, device, epoch, loss_scaler, in_domains, out_domains, split="train")
        train_one_epoch(model, val_loader, tasks_loss_fn, optimizer, device, epoch, loss_scaler, in_domains, out_domains, split="val")

    # ----------------- Test -----------------
    test_stats = train_one_epoch(model, test_loader, tasks_loss_fn, optimizer, device, epochs, loss_scaler, in_domains, out_domains, split="test")
    print("✅ Test performance:", test_stats)

    # ----------------- Save -----------------
    torch.save(model.state_dict(), "finetuned_multimae_harvest.pth")
    wandb.finish()
    print("✅ Fine-tuned model saved at finetuned_multimae_harvest.pth")

if __name__ == "__main__":
    # from my_dataset import train_loader, val_loader, test_loader
    # main(train_loader, val_loader, test_loader)
    pass


In [16]:
main(train_loader, val_loader, test_loader)

[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁
step,▁
train_loss_total,▁
train_s1_loss,▁
train_s2_loss,▁

0,1
epoch,0.0
step,0.0
train_grad_norm,inf
train_loss_total,13.19023
train_s1_loss,6.73726
train_s2_loss,6.45297


  self._scaler = torch.cuda.amp.GradScaler(enabled=enabled)
  with torch.cuda.amp.autocast():


IndexError: list index out of range