# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import pywt
import numpy as np
from einops import rearrange, repeat

# If you have not installed lion_pytorch yet, install it via:
# pip install lion-pytorch
from lion_pytorch import Lion

import sys
sys.path.append("..")

from Utils.TinyImageNet_loader import get_tinyimagenet_dataloaders

# ------------------- Data Loading ------------------- 
# Side length forwarded to the project data-loader helper.
# NOTE(review): every transform below crops to 64x64, so this 224 looks
# inconsistent with what the model receives -- confirm how
# get_tinyimagenet_dataloaders actually uses `image_size`.
image_size = 224

# Per-channel statistics shared by the Normalize step of every split.
_norm_kwargs = dict(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225])

# Training: random scale/crop to 64x64 plus horizontal flip.
tiny_transform_train = transforms.Compose([
    transforms.RandomResizedCrop(64, scale=(0.6, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(**_norm_kwargs),
])

# Validation and test use the same deterministic recipe (resize to 72, then
# centre-crop to 64); each split still gets its own Compose instance.
tiny_transform_val, tiny_transform_test = (
    transforms.Compose([
        transforms.Resize(72),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize(**_norm_kwargs),
    ])
    for _ in range(2)
)

_loader_config = {
    'data_dir': '../datasets',
    'transform_train': tiny_transform_train,
    'transform_val': tiny_transform_val,
    'transform_test': tiny_transform_test,
    'batch_size': 64,
    'image_size': image_size,
}
train_loader, val_loader, test_loader = get_tinyimagenet_dataloaders(**_loader_config)

# ---------------------------------------------------------------------
# Two wavelet modules:
#   1) LearnableDWTDownsample => stride=2, kernel_size=2 (used for stem & pools)
#   2) LearnableDWTNoDownsample => stride=1, kernel_size=3, padding=1 (used in blocks)
# ---------------------------------------------------------------------

class LearnableDWTDownsample(nn.Module):
    """
    Learnable 2-D wavelet downsampling, initialised as an orthonormal Haar DWT.

    A depthwise convolution (kernel=2, stride=2, groups=in_channels) maps every
    input channel to four sub-band channels and halves both spatial dimensions.
    The filters remain ordinary trainable parameters after initialisation.

    Output channels for input channel ``c`` are laid out as
    ``4*c + [0, 1, 2, 3] = [LL, LH, HL, HH]``.
    """
    def __init__(self, in_channels):
        """
        in_channels: number of channels of the (B, C, H, W) input.
        """
        super().__init__()
        self.in_channels = in_channels
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=in_channels * 4,
            kernel_size=2,
            stride=2,
            groups=in_channels
        )
        self.initialize_weights()

    @staticmethod
    def _haar_bank():
        """
        Return the four separable 2x2 Haar kernels as a (4, 2, 2) tensor.

        Each kernel is the outer product ``rows[:, None] * cols[None, :]`` of a
        1-D filter applied along the height axis (rows) and one applied along
        the width axis (cols):

            LL: low  along height, low  along width
            LH: high along height, low  along width
            HL: low  along height, high along width
            HH: high along height, high along width

        With 1/sqrt(2)-normalised 1-D filters every entry is +-1/2, so the four
        kernels form an orthonormal basis of the 2x2 patch.
        """
        low = torch.tensor([1.0, 1.0]) * 2 ** -0.5
        high = torch.tensor([-1.0, 1.0]) * 2 ** -0.5
        pairs = ((low, low), (high, low), (low, high), (high, high))
        return torch.stack([rows[:, None] * cols[None, :] for rows, cols in pairs])

    def initialize_weights(self):
        """Set every channel's four filters to the 2-D Haar bank; zero the bias."""
        bank = self._haar_bank()  # (4, 2, 2)
        with torch.no_grad():
            # Tile the bank once per input channel: (4, 2, 2) -> (4*C, 1, 2, 2).
            # Row 4*c + s holds sub-band s of channel c, matching the grouped
            # convolution's output-channel order.
            self.conv.weight.copy_(bank.repeat(self.in_channels, 1, 1).unsqueeze(1))
            # A randomly initialised bias would offset every coefficient, so
            # the layer would not start out as a true wavelet transform.
            self.conv.bias.zero_()

    def forward(self, x):
        """
        x: (B, C, H, W)
        returns (B, C, H/2, W/2, 4) with the last axis ordered LL, LH, HL, HH
        """
        x = self.conv(x)  # => (B, 4*C, H/2, W/2)
        return rearrange(x, 'b (c g) h w -> b c h w g', g=4)


class LearnableDWTNoDownsample(nn.Module):
    """
    Wavelet-like transform that keeps the spatial resolution.

    A depthwise convolution (kernel=3, stride=1, padding=1,
    groups=in_channels) produces four sub-band channels per input channel while
    preserving (H, W). The filters are initialised with the orthonormal 2-D
    Haar bank and remain trainable.

    Output channels for input channel ``c`` are laid out as
    ``4*c + [0, 1, 2, 3] = [LL, LH, HL, HH]``.
    """
    def __init__(self, in_channels):
        """
        in_channels: number of channels of the (B, C, H, W) input.
        """
        super().__init__()
        self.in_channels = in_channels
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=in_channels * 4,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=in_channels
        )
        self.initialize_weights()

    @staticmethod
    def _haar_bank():
        """
        Return the four separable 2x2 Haar kernels as a (4, 2, 2) tensor.

        Each kernel is the outer product ``rows[:, None] * cols[None, :]`` of a
        1-D filter applied along the height axis (rows) and one applied along
        the width axis (cols):

            LL: low  along height, low  along width
            LH: high along height, low  along width
            HL: low  along height, high along width
            HH: high along height, high along width

        With 1/sqrt(2)-normalised 1-D filters every entry is +-1/2.
        """
        low = torch.tensor([1.0, 1.0]) * 2 ** -0.5
        high = torch.tensor([-1.0, 1.0]) * 2 ** -0.5
        pairs = ((low, low), (high, low), (low, high), (high, high))
        return torch.stack([rows[:, None] * cols[None, :] for rows, cols in pairs])

    def initialize_weights(self):
        """Place the 2x2 Haar kernel in the top-left corner, zero elsewhere."""
        kernels = torch.zeros(4, 3, 3)
        # Only the top-left 2x2 taps are non-zero; the third row and column of
        # each 3x3 kernel start at zero.
        kernels[:, :2, :2] = self._haar_bank()
        with torch.no_grad():
            # Tile once per input channel: (4, 3, 3) -> (4*C, 1, 3, 3).
            # Row 4*c + s holds sub-band s of channel c.
            self.conv.weight.copy_(kernels.repeat(self.in_channels, 1, 1).unsqueeze(1))
            # A randomly initialised bias would offset every coefficient, so
            # the layer would not start out as a true wavelet transform.
            self.conv.bias.zero_()

    def forward(self, x):
        """
        x: (B, C, H, W)
        returns (B, C, H, W, 4): same spatial size, last axis ordered
        LL, LH, HL, HH.
        """
        x = self.conv(x)  # => (B, 4*C, H, W)
        return rearrange(x, 'b (c g) h w -> b c h w g', g=4)


# ------------------- WaveletAttention -------------------
class WaveletAttention(nn.Module):
    """
    Multi-head self-attention over a token sequence.

    The attended features are multiplied by a learnable per-channel sigmoid
    gate (``hf_gate``, initialised to zeros so the gate starts at 0.5) before
    the output projection.
    """
    def __init__(self, dim, num_heads=4):
        """
        dim: channel width of the tokens.
        num_heads: number of attention heads; each head gets dim // num_heads
            channels.
        """
        super().__init__()
        head_dim = dim // num_heads
        self.num_heads = num_heads
        self.scale = head_dim ** -0.5  # 1 / sqrt(head_dim) score scaling

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        self.hf_gate = nn.Parameter(torch.zeros(1, 1, dim))

    def forward(self, x):
        """
        x: (B, N, C) tokens
        returns (B, N, C)
        """
        batch, tokens, channels = x.shape
        heads = self.num_heads

        # One fused projection, then split the last axis into (3, heads, head_dim).
        packed = self.qkv(x).reshape(batch, tokens, 3, heads, channels // heads)
        # => three tensors of shape (B, heads, N, head_dim)
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(dim=0)

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        weights = scores.softmax(dim=-1)

        # Merge the heads back into the channel axis.
        mixed = torch.matmul(weights, v).transpose(1, 2).reshape(batch, tokens, channels)

        # Channel-wise gate applied just before the output projection.
        gate = torch.sigmoid(self.hf_gate)
        return self.proj(mixed * gate)


# ------------------- Model Architecture -------------------
class WOCBlock(nn.Module):
    """
    Resolution-preserving wavelet block.

    The incoming tokens are reshaped to a square grid, decomposed into four
    sub-bands by a stride-1 wavelet layer, and recombined as
    ``LL + sigmoid(hf_gate) * (LH + HL + HH)``. That mix replaces the input
    tokens and then passes through pre-norm attention and a pre-norm FFN,
    each with a residual connection.
    """
    def __init__(self, dim, num_heads=4):
        """
        dim: token channel width.
        num_heads: attention heads used by the WaveletAttention sub-layer.
        """
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = WaveletAttention(dim, num_heads)
        self.norm2 = nn.LayerNorm(dim)

        # Two-layer MLP with a 2x hidden expansion.
        hidden = dim * 2
        self.ffn = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, dim)
        )

        # Stride-1 wavelet layer: four sub-bands per channel, same (H, W).
        self.dwt_no_down = LearnableDWTNoDownsample(in_channels=dim)
        # Per-channel gate on the summed high-frequency sub-bands.
        self.hf_gate = nn.Parameter(torch.zeros(1, 1, dim))

    def forward(self, x):
        """
        x: (B, N, C) where N is expected to be a perfect square (N = H*W, H == W).
        returns (B, N, C)
        """
        _, num_tokens, _ = x.shape
        side = int(num_tokens ** 0.5)  # grid is assumed square

        # Tokens -> feature map (B, C, H, W)
        grid = rearrange(x, 'b (h w) c -> b c h w', h=side, w=side)

        # (B, C, H, W, 4) -> four (B, C, H, W) sub-bands
        ll, lh, hl, hh = self.dwt_no_down(grid).unbind(dim=-1)

        # Collapse the three detail sub-bands into a single map.
        detail = lh + hl + hh

        # Back to token layout; only the detail path is gated.
        low_tokens = rearrange(ll, 'b c h w -> b (h w) c')
        detail_tokens = rearrange(detail, 'b c h w -> b (h w) c')
        detail_tokens = detail_tokens * torch.sigmoid(self.hf_gate)

        x = low_tokens + detail_tokens

        # Pre-norm residual attention followed by pre-norm residual FFN.
        x = x + self.attn(self.norm1(x))
        return x + self.ffn(self.norm2(x))


class WOCSwin(nn.Module):
    """
    Hierarchical wavelet/attention classifier.

    Pipeline:
      - stem: wavelet downsample of the 3-channel image (halves H and W),
        followed by a token embedding to width 96
      - four stages of WOCBlocks (widths 96/192/384/768, depths 2/2/6/2);
        blocks keep the token grid size
      - a wavelet downsample "pool" + linear projection between consecutive
        stages (halves the grid side, doubles the width)
      - head: average over tokens, then a linear classifier
    """
    def __init__(self, num_classes=200):
        """
        num_classes: size of the classifier output.
        """
        super().__init__()
        widths = (96, 192, 384, 768)
        depths = (2, 2, 6, 2)

        # 1) Stem: 3 channels -> 3*4 sub-band channels at half resolution.
        self.stem = LearnableDWTDownsample(in_channels=3)

        # Per-token embedding of the 12 stem coefficients.
        self.embed = nn.Sequential(
            nn.Linear(12, widths[0]),
            nn.GELU()
        )

        # 2) Stages: resolution-preserving blocks at a fixed width.
        self.stages = nn.ModuleList([
            nn.Sequential(*[WOCBlock(dim=width, num_heads=4) for _ in range(depth)])
            for width, depth in zip(widths, depths)
        ])

        # 3) Pools: wavelet downsample, then project 4*narrow -> wide.
        self.pools = nn.ModuleList([
            nn.Sequential(
                LearnableDWTDownsample(in_channels=narrow),
                nn.Linear(narrow * 4, wide)
            )
            for narrow, wide in zip(widths[:-1], widths[1:])
        ])

        # 4) Classifier head over the token-averaged features.
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Linear(widths[-1], num_classes)
        )

    def forward(self, x):
        """
        x: (B, 3, H, W), e.g. H=W=64 for TinyImageNet
        returns (B, num_classes) logits
        """
        # Stem: (B, 3, H/2, W/2, 4) -> tokens (B, (H/2)*(W/2), 12) -> width 96
        tokens = rearrange(self.stem(x), 'b c h w g -> b (h w) (c g)')
        tokens = self.embed(tokens)

        # Every stage except the last is followed by its wavelet pool.
        *pooled_stages, last_stage = self.stages
        for stage, pool_layer in zip(pooled_stages, self.pools):
            tokens = wavelet_pool(stage(tokens), pool_layer)

        # Final stage keeps its resolution (width 768).
        tokens = last_stage(tokens)

        # AdaptiveAvgPool1d pools the last axis, so put tokens last: (B, 768, N)
        return self.head(tokens.transpose(1, 2))


def wavelet_pool(x, pool_module):
    """
    Downsample a square token grid with a wavelet layer and re-project it.

    x: (B, N, X) tokens, N expected to be a perfect square (H == W).
    pool_module: a two-element iterable
        [LearnableDWTDownsample(in_channels=X), Linear(X*4, Y)].

    Flow:
      tokens (B, N, X)
        -> feature map (B, X, H, W)
        -> wavelet downsample (B, X, H/2, W/2, 4)
        -> tokens (B, (H/2)*(W/2), X*4)
        -> linear (B, (H/2)*(W/2), Y)
    """
    _, num_tokens, _ = x.shape
    side = int(num_tokens ** 0.5)  # grid is assumed square
    grid = rearrange(x, 'b (h w) c -> b c h w', h=side, w=side)

    dwt_layer, linear_layer = pool_module
    subbands = dwt_layer(grid)  # => (B, X, H/2, W/2, 4)
    # Fold the four sub-bands into the channel axis before projecting.
    merged = rearrange(subbands, 'b c h w g -> b (h w) (c g)')
    return linear_layer(merged)

# ------------------- Training Loop -------------------
class ProgressiveTrainer:
    """
    Training/validation driver with epoch-dependent input-noise augmentation.

    Uses the Lion optimizer, a cosine-annealed learning rate and cross-entropy
    loss. Reported loss and accuracy are averaged over the samples actually
    drawn from the loader, so they stay correct with ``drop_last=True``,
    custom samplers, or a short final batch.
    """
    def __init__(self, model, train_loader, val_loader, device):
        """
        model: network to train; moved to `device`.
        train_loader / val_loader: iterables yielding (images, labels) batches.
        device: torch device used for the model and every batch.
        """
        self.device = device
        self.model = model.to(self.device)
        self.train_loader = train_loader
        self.val_loader = val_loader

        # Lion optimizer
        self.optimizer = Lion(self.model.parameters(), lr=3e-4, betas=(0.95, 0.98))
        # NOTE: T_max is fixed at 300, matching the default `epochs` of
        # train(); with a different epoch count the schedule will not span
        # exactly one cosine half-period.
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=300)
        # Default reduction ("mean") is relied on when averaging per sample.
        self.criterion = nn.CrossEntropyLoss()

    @staticmethod
    def _summarise(loss_sum, correct, seen):
        """Turn running totals into (mean loss, accuracy); (0.0, 0.0) if no samples."""
        if seen == 0:
            return 0.0, 0.0
        return loss_sum / seen, correct / seen

    def mix_features(self, x, epoch):
        """
        Add Gaussian noise whose strength grows with the epoch index.

        The noise standard deviation is ``0.3 * ratio`` where ratio ramps
        0 -> 0.2 over epochs [0, 50), 0.2 -> 1.0 over [50, 150), and keeps
        growing linearly from 1.0 afterwards (1.5 at epoch 300).
        """
        if epoch < 50:
            ratio = 0.2 * (epoch / 50)
        elif epoch < 150:
            ratio = 0.2 + 0.8 * ((epoch - 50) / 100)
        else:
            ratio = 1.0 + 0.5 * ((epoch - 150) / 150)

        # Apply high-frequency noise
        x = x + ratio * torch.randn_like(x) * 0.3
        return x

    def train_epoch(self, epoch):
        """
        Run one optimisation pass over the training loader.

        epoch: zero-based epoch index, used only to set the noise strength.
        returns (mean per-sample loss, accuracy) over the samples seen.
        """
        self.model.train()
        loss_sum = 0.0
        correct = 0
        seen = 0

        for images, labels in self.train_loader:
            images = images.to(self.device)
            labels = labels.to(self.device)

            # Progressive mixing
            images = self.mix_features(images, epoch)

            self.optimizer.zero_grad()
            outputs = self.model(images)
            loss = self.criterion(outputs, labels)

            loss.backward()
            self.optimizer.step()

            batch_size = labels.size(0)
            # The criterion returns the batch mean; weight it by the batch
            # size so a short final batch is not over-represented.
            loss_sum += loss.item() * batch_size
            correct += outputs.argmax(dim=1).eq(labels).sum().item()
            seen += batch_size

        return self._summarise(loss_sum, correct, seen)

    def validate(self):
        """
        Evaluate on the validation loader without gradient tracking.

        returns (mean per-sample loss, accuracy) over the samples seen.
        """
        self.model.eval()
        loss_sum = 0.0
        correct = 0
        seen = 0

        with torch.no_grad():
            for images, labels in self.val_loader:
                images = images.to(self.device)
                labels = labels.to(self.device)

                outputs = self.model(images)
                loss = self.criterion(outputs, labels)

                batch_size = labels.size(0)
                loss_sum += loss.item() * batch_size
                correct += outputs.argmax(dim=1).eq(labels).sum().item()
                seen += batch_size

        return self._summarise(loss_sum, correct, seen)

    def train(self, epochs=300):
        """
        Train for `epochs` epochs, validating and stepping the LR schedule
        after each one. The weights with the best validation accuracy so far
        are written to "woc_swin_best.pth".
        """
        best_acc = 0.0

        for epoch in range(epochs):
            train_loss, train_acc = self.train_epoch(epoch)
            val_loss, val_acc = self.validate()
            self.scheduler.step()

            print(f"Epoch {epoch+1}/{epochs}")
            print(f"Train Loss: {train_loss:.4f} | Acc: {train_acc*100:.2f}%")
            print(f"Val   Loss: {val_loss:.4f} | Acc: {val_acc*100:.2f}%\n")

            # Save best model
            if val_acc > best_acc:
                best_acc = val_acc
                torch.save(self.model.state_dict(), "woc_swin_best.pth")


# ------------------- Main Execution -------------------
if __name__ == "__main__":
    # Run on the GPU when torch can see one, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    print("Using device:", device)

    # Build the 200-way classifier.
    model = WOCSwin(num_classes=200)

    # Report how many scalar weights the optimizer will update.
    trainable = [p for p in model.parameters() if p.requires_grad]
    total_params = sum(p.numel() for p in trainable)
    print(f"Total trainable parameters: {total_params}")

    # Wire the model to the module-level train/val loaders and train.
    trainer = ProgressiveTrainer(model, train_loader, val_loader, device)
    trainer.train(epochs=300)


# Output: Using device: cpu
