# ARSIVAE DATALOADER

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from skimage.feature import graycomatrix, graycoprops
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.metrics import r2_score, mean_absolute_error
import seaborn as sns

class CTDataset_ARSIVAE(Dataset):
    """Dataset with all 14 physics attributes for AR-SI-VAE.

    Every CSV row must provide ``ct_path``, ``mu_path`` and ``mask_path``
    (files readable by ``np.load``) together with ``label`` and ``id``.
    The 14 attributes are either read from precomputed CSV columns or derived
    from the CT slice and the lung mask each time a sample is fetched.

    Args:
        csv_path: Path of the CSV index file.
        compute_on_fly: When True the attributes are always recomputed from
            the arrays; when False the precomputed CSV columns are used.

    Raises:
        ValueError: If ``compute_on_fly`` is False and the CSV lacks one of
            the precomputed attribute columns.
    """

    # Order of the attribute vector; these are also the CSV column names that
    # must all be present for the precomputed path.
    _ATTRIBUTE_NAMES = (
        'mean_HU', 'HU_std', 'HU_p10', 'HU_p25', 'HU_p50', 'HU_p75', 'HU_p90',
        'mask_area_pixels', 'mask_fraction', 'grad_mean', 'grad_std',
        'contrast', 'homogeneity', 'entropy',
    )
    _HU_NAMES = _ATTRIBUTE_NAMES[:7]

    def __init__(self, csv_path, compute_on_fly=True):
        self.df = pd.read_csv(csv_path)
        self.compute_on_fly = compute_on_fly
        self.has_precomputed = self._check_precomputed_features()

        if not self.has_precomputed and not compute_on_fly:
            raise ValueError("CSV missing precomputed features and compute_on_fly=False")

        print(f"\n{'='*70}")
        print(f"AR-SI-VAE Dataset Loaded")
        print(f"{'='*70}")
        print(f"Total samples: {len(self):,}")
        print(f"  COVID:  {len(self.df[self.df['label']==1]):,}")
        print(f"  Normal: {len(self.df[self.df['label']==0]):,}")
        print(f"Physics: {'On-the-fly' if compute_on_fly else 'Precomputed'}")
        print(f"{'='*70}\n")

    @staticmethod
    def _default_texture():
        """Return a fresh copy of the fallback texture features."""
        return {'contrast': 0.0, 'homogeneity': 1.0, 'entropy': 0.0}

    def _check_precomputed_features(self):
        """Return True when every attribute column exists in the CSV."""
        return all(col in self.df.columns for col in self._ATTRIBUTE_NAMES)

    def __len__(self):
        return len(self.df)

    def _compute_hu_features(self, ct, mask):
        """Mean, std and 10/25/50/75/90th percentiles of ``ct`` inside the mask.

        Pixels with ``mask > 0.5`` count as lung. All features are 0.0 when
        the mask selects nothing.
        """
        hu_values = ct[mask > 0.5]

        if len(hu_values) == 0:
            return {name: 0.0 for name in self._HU_NAMES}

        features = {
            'mean_HU': float(np.mean(hu_values)),
            'HU_std': float(np.std(hu_values)),
        }
        for pct in (10, 25, 50, 75, 90):
            features[f'HU_p{pct}'] = float(np.percentile(hu_values, pct))
        return features

    def _compute_shape_features(self, mask, image_size=512*512):
        """Lung area in pixels and its fraction of ``image_size``."""
        mask_area = float(np.sum(mask > 0.5))
        return {
            'mask_area_pixels': mask_area,
            'mask_fraction': mask_area / image_size
        }

    def _compute_gradient_features(self, ct, mask):
        """Mean and std of the gradient magnitude of ``ct`` inside the mask."""
        grad_y, grad_x = np.gradient(ct)
        grad_magnitude = np.sqrt(grad_x**2 + grad_y**2)
        grad_in_lung = grad_magnitude[mask > 0.5]

        if len(grad_in_lung) == 0:
            return {'grad_mean': 0.0, 'grad_std': 0.0}

        return {
            'grad_mean': float(np.mean(grad_in_lung)),
            'grad_std': float(np.std(grad_in_lung))
        }

    def _compute_texture_features(self, ct, mask):
        """GLCM contrast, homogeneity and entropy of the masked slice.

        Returns the fallback values (contrast 0, homogeneity 1, entropy 0)
        when the mask is empty, the lung region is constant, or the GLCM
        computation fails.
        """
        lung_pixels = mask > 0.5
        if lung_pixels.sum() == 0:
            return self._default_texture()

        ct_masked = ct.copy()
        # Background is filled with the minimum lung value, so after the
        # rescale below it lands on gray level 0. The full image, background
        # included, is still passed to graycomatrix.
        ct_masked[~lung_pixels] = ct_masked[lung_pixels].min()
        ct_min = ct_masked[lung_pixels].min()
        ct_max = ct_masked[lung_pixels].max()

        if ct_max == ct_min:
            return self._default_texture()

        ct_normalized = ((ct_masked - ct_min) / (ct_max - ct_min) * 255).astype(np.uint8)

        try:
            glcm = graycomatrix(ct_normalized, distances=[1],
                                angles=[0, np.pi/4, np.pi/2, 3*np.pi/4],
                                levels=256, symmetric=True, normed=True)

            contrast = graycoprops(glcm, 'contrast').mean()
            homogeneity = graycoprops(glcm, 'homogeneity').mean()
            # Renormalise over the whole array (all angles together) before
            # taking the Shannon entropy; the epsilon avoids log2(0).
            glcm_norm = glcm / (glcm.sum() + 1e-10)
            entropy = -np.sum(glcm_norm * np.log2(glcm_norm + 1e-10))
        except Exception:
            # Best-effort: a failed texture computation must not abort data
            # loading. Unlike a bare ``except:``, this lets KeyboardInterrupt
            # and SystemExit propagate.
            return self._default_texture()

        return {
            'contrast': float(contrast),
            'homogeneity': float(homogeneity),
            'entropy': float(entropy)
        }

    def _compute_all_physics(self, ct, mu, mask):
        """Assemble the 14-element float32 attribute vector for one sample.

        ``mu`` is accepted for interface compatibility but is not used.
        HU statistics use the denormalised slice; gradient and texture
        features use ``ct`` as loaded.
        """
        # NOTE(review): the later redefinition of this class in this file uses
        # ``ct * 1400 - 1000``; confirm which window matches the preprocessing.
        ct_hu = ct * 700 - 300  # Denormalize to HU

        features = {}
        features.update(self._compute_hu_features(ct_hu, mask))
        features.update(self._compute_shape_features(mask))
        features.update(self._compute_gradient_features(ct, mask))
        features.update(self._compute_texture_features(ct, mask))

        return np.array([features[name] for name in self._ATTRIBUTE_NAMES],
                        dtype=np.float32)

    def __getitem__(self, idx):
        """Load one sample and return it as a dict of tensors plus its id."""
        row = self.df.iloc[idx]

        ct = np.load(row['ct_path'])
        mu = np.load(row['mu_path'])
        mask = np.load(row['mask_path'])

        if self.has_precomputed and not self.compute_on_fly:
            attributes = np.array([row[name] for name in self._ATTRIBUTE_NAMES],
                                  dtype=np.float32)
        else:
            attributes = self._compute_all_physics(ct, mu, mask)

        return {
            'ct': torch.FloatTensor(ct).unsqueeze(0),
            'mu': torch.FloatTensor(mu).unsqueeze(0),
            'mask': torch.FloatTensor(mask).unsqueeze(0),
            'attributes': torch.FloatTensor(attributes),
            'label': torch.tensor(row['label'], dtype=torch.long),
            'id': row['id']
        }

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from skimage.feature import graycomatrix, graycoprops
from sklearn.decomposition import PCA
from sklearn.metrics import r2_score
import os
import random

def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: Integer seed handed to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Deterministic kernels on, autotuner off.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

class CTDataset_ARSIVAE(Dataset):
    """CT-slice dataset yielding image, lung mask and 14 physics attributes.

    Every CSV row must provide ``ct_path`` and ``mask_path`` (files readable
    by ``np.load``) together with ``label`` and ``id``. Attributes are read
    from precomputed CSV columns or derived from the arrays on each access,
    and are optionally standardised with ``attr_mean`` / ``attr_std``.

    Args:
        csv_path: Path of the CSV index file.
        compute_on_fly: When True the attributes are always recomputed from
            the arrays; when False the precomputed CSV columns are used.
        attr_mean: Optional per-attribute mean used for standardisation.
        attr_std: Optional per-attribute std used for standardisation.
            Standardisation is applied only when both are given.

    Raises:
        ValueError: If ``compute_on_fly`` is False and the CSV lacks one of
            the precomputed attribute columns.
    """

    # Order of the attribute vector; these are also the CSV column names that
    # must all be present for the precomputed path.
    _ATTRIBUTE_NAMES = (
        'mean_HU', 'HU_std', 'HU_p10', 'HU_p25', 'HU_p50', 'HU_p75', 'HU_p90',
        'mask_area_pixels', 'mask_fraction', 'grad_mean', 'grad_std',
        'contrast', 'homogeneity', 'entropy',
    )
    _HU_NAMES = _ATTRIBUTE_NAMES[:7]

    def __init__(self, csv_path, compute_on_fly=True, attr_mean=None, attr_std=None):
        self.df = pd.read_csv(csv_path)
        self.compute_on_fly = compute_on_fly
        self.has_precomputed = self._check_precomputed_features()
        self.attr_mean = attr_mean
        self.attr_std = attr_std
        if not self.has_precomputed and not compute_on_fly:
            raise ValueError("CSV missing precomputed features")

    @staticmethod
    def _default_texture():
        """Return a fresh copy of the fallback texture features."""
        return {'contrast': 0.0, 'homogeneity': 1.0, 'entropy': 0.0}

    def _check_precomputed_features(self):
        """Return True when every attribute column exists in the CSV."""
        return all(col in self.df.columns for col in self._ATTRIBUTE_NAMES)

    def __len__(self):
        return len(self.df)

    def _compute_hu_features(self, ct, mask):
        """Mean, std and 10/25/50/75/90th percentiles of ``ct`` inside the mask.

        Pixels with ``mask > 0.5`` count as lung. All features are 0.0 when
        the mask selects nothing.
        """
        hu_values = ct[mask > 0.5]
        if len(hu_values) == 0:
            return {name: 0.0 for name in self._HU_NAMES}
        features = {
            'mean_HU': float(np.mean(hu_values)),
            'HU_std': float(np.std(hu_values)),
        }
        for pct in (10, 25, 50, 75, 90):
            features[f'HU_p{pct}'] = float(np.percentile(hu_values, pct))
        return features

    def _compute_shape_features(self, mask, image_size=512*512):
        """Lung area in pixels and its fraction of ``image_size``."""
        mask_area = float(np.sum(mask > 0.5))
        return {'mask_area_pixels': mask_area, 'mask_fraction': mask_area / image_size}

    def _compute_gradient_features(self, ct, mask):
        """Mean and std of the gradient magnitude of ``ct`` inside the mask."""
        grad_y, grad_x = np.gradient(ct)
        grad_magnitude = np.sqrt(grad_x**2 + grad_y**2)
        grad_in_lung = grad_magnitude[mask > 0.5]
        if len(grad_in_lung) == 0:
            return {'grad_mean': 0.0, 'grad_std': 0.0}
        return {
            'grad_mean': float(np.mean(grad_in_lung)),
            'grad_std': float(np.std(grad_in_lung)),
        }

    def _compute_texture_features(self, ct, mask):
        """GLCM contrast, homogeneity and entropy of the masked slice.

        Returns the fallback values (contrast 0, homogeneity 1, entropy 0)
        when the mask is empty, the lung region is constant, or the GLCM
        computation fails.
        """
        lung_pixels = mask > 0.5
        if lung_pixels.sum() == 0:
            return self._default_texture()
        ct_masked = ct.copy()
        # Background is filled with the minimum lung value, so after the
        # rescale below it lands on gray level 0. The full image, background
        # included, is still passed to graycomatrix.
        ct_masked[~lung_pixels] = ct_masked[lung_pixels].min()
        ct_min = ct_masked[lung_pixels].min()
        ct_max = ct_masked[lung_pixels].max()
        if ct_max == ct_min:
            return self._default_texture()
        ct_normalized = ((ct_masked - ct_min) / (ct_max - ct_min) * 255).astype(np.uint8)
        try:
            glcm = graycomatrix(ct_normalized, distances=[1],
                                angles=[0, np.pi/4, np.pi/2, 3*np.pi/4],
                                levels=256, symmetric=True, normed=True)
            contrast = graycoprops(glcm, 'contrast').mean()
            homogeneity = graycoprops(glcm, 'homogeneity').mean()
            # Renormalise over the whole array (all angles together) before
            # taking the Shannon entropy; the epsilon avoids log2(0).
            glcm_norm = glcm / (glcm.sum() + 1e-10)
            entropy = -np.sum(glcm_norm * np.log2(glcm_norm + 1e-10))
        except Exception:
            # Best-effort: a failed texture computation must not abort data
            # loading. Unlike a bare ``except:``, this lets KeyboardInterrupt
            # and SystemExit propagate.
            return self._default_texture()
        return {
            'contrast': float(contrast),
            'homogeneity': float(homogeneity),
            'entropy': float(entropy),
        }

    def _compute_all_physics(self, ct, mask):
        """Assemble the 14-element float32 attribute vector for one sample.

        HU statistics use the rescaled slice; gradient and texture features
        use ``ct`` as loaded.
        """
        # Presumably maps normalised intensities back to Hounsfield units
        # (a [-1000, 400] window if ct is in [0, 1]) - TODO confirm against
        # the preprocessing that produced the .npy files.
        ct_hu = ct * 1400 - 1000
        features = {}
        features.update(self._compute_hu_features(ct_hu, mask))
        features.update(self._compute_shape_features(mask))
        features.update(self._compute_gradient_features(ct, mask))
        features.update(self._compute_texture_features(ct, mask))
        return np.array([features[name] for name in self._ATTRIBUTE_NAMES],
                        dtype=np.float32)

    def __getitem__(self, idx):
        """Load one sample and return it as a dict of tensors plus its id."""
        row = self.df.iloc[idx]
        ct = np.load(row['ct_path'])
        mask = np.load(row['mask_path'])
        if self.has_precomputed and not self.compute_on_fly:
            attributes = np.array([row[name] for name in self._ATTRIBUTE_NAMES],
                                  dtype=np.float32)
        else:
            attributes = self._compute_all_physics(ct, mask)
        if self.attr_mean is not None and self.attr_std is not None:
            # The epsilon guards against division by a zero std.
            attributes = (attributes - self.attr_mean) / (self.attr_std + 1e-8)
        return {
            'ct': torch.FloatTensor(ct).unsqueeze(0),
            'mask': torch.FloatTensor(mask).unsqueeze(0),
            'attributes': torch.FloatTensor(attributes),
            'label': torch.tensor(row['label'], dtype=torch.long),
            'id': row['id'],
        }

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder producing the mean and log-variance of q(z|x).

    Five Conv(4, stride 2, pad 1) + BatchNorm + LeakyReLU(0.2) stages widen
    the channels 1 -> 32 -> 64 -> 128 -> 256 -> 512. The linear heads expect
    a flattened 512 x 16 x 16 feature map, which corresponds to 512 x 512
    single-channel inputs.

    Args:
        latent_dim: Size of the latent vector.
    """

    def __init__(self, latent_dim=64):
        super().__init__()
        channels = (1, 32, 64, 128, 256, 512)
        stages = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            stages.append(nn.Conv2d(c_in, c_out, 4, 2, 1))
            stages.append(nn.BatchNorm2d(c_out))
            stages.append(nn.LeakyReLU(0.2))
        self.conv = nn.Sequential(*stages)

        flat_features = 512 * 16 * 16
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)
        self._init_weights()

    def _init_weights(self):
        """Kaiming-init every conv/linear layer, then re-init the two heads."""
        for module in self.modules():
            if not isinstance(module, (nn.Conv2d, nn.Linear)):
                continue
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='leaky_relu')
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

        # The heads get small weights; the log-variance bias of -5 makes the
        # initial posterior variance small.
        for head, bias_value in ((self.fc_mu, 0), (self.fc_logvar, -5)):
            nn.init.xavier_normal_(head.weight, gain=0.01)
            nn.init.constant_(head.bias, bias_value)

    def forward(self, x):
        """Return ``(mu, logvar)``, with ``logvar`` clamped to [-10, 2]."""
        features = self.conv(x)
        flat = features.view(x.size(0), -1)
        mu = self.fc_mu(flat)
        logvar = torch.clamp(self.fc_logvar(flat), -10, 2)
        return mu, logvar

class Decoder(nn.Module):
    """Transposed-convolution decoder mapping a latent vector to an image.

    A linear layer expands ``z`` to a 512 x 16 x 16 map, then five
    ConvTranspose(4, stride 2, pad 1) layers narrow the channels
    512 -> 256 -> 128 -> 64 -> 32 -> 1. The output passes through ``Tanh``.

    Args:
        latent_dim: Size of the latent vector.
    """

    def __init__(self, latent_dim=64):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 512 * 16 * 16)

        channels = (512, 256, 128, 64, 32)
        layers = []
        for c_in, c_out in zip(channels, channels[1:]):
            layers.append(nn.ConvTranspose2d(c_in, c_out, 4, 2, 1))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2))
        # The final stage has no normalisation and ends in Tanh.
        layers.append(nn.ConvTranspose2d(32, 1, 4, 2, 1))
        layers.append(nn.Tanh())
        self.deconv = nn.Sequential(*layers)
        self._init_weights()

    def _init_weights(self):
        """Kaiming-init all transposed-conv and linear layers, zero biases."""
        for module in self.modules():
            if not isinstance(module, (nn.ConvTranspose2d, nn.Linear)):
                continue
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='leaky_relu')
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def forward(self, z):
        """Decode ``z`` of shape (batch, latent_dim) into an image batch."""
        seed_map = self.fc(z).view(z.size(0), 512, 16, 16)
        return self.deconv(seed_map)

In [None]:
class AttributePredictor(nn.Module):
    """MLP with one residual block that regresses attributes from a latent.

    Layout: Linear(latent, 256) -> residual block of two Linear(256, 256)
    layers -> Linear(256, 128) with dropout 0.1 -> Linear(128, n_attributes).
    Every hidden linear layer is followed by BatchNorm1d.

    Args:
        latent_dim: Size of the input latent vector.
        n_attributes: Number of regressed attributes.
    """

    def __init__(self, latent_dim=64, n_attributes=14):
        super().__init__()
        hidden, narrow = 256, 128
        self.input_layer = nn.Linear(latent_dim, hidden)
        self.bn1 = nn.BatchNorm1d(hidden)
        self.res1 = nn.Linear(hidden, hidden)
        self.bn_res1 = nn.BatchNorm1d(hidden)
        self.res2 = nn.Linear(hidden, hidden)
        self.bn_res2 = nn.BatchNorm1d(hidden)
        self.fc2 = nn.Linear(hidden, narrow)
        self.bn2 = nn.BatchNorm1d(narrow)
        self.fc3 = nn.Linear(narrow, n_attributes)
        self.dropout = nn.Dropout(0.1)
        self._init_weights()

    def _init_weights(self):
        """Kaiming-init every linear layer and zero its bias."""
        linear_layers = (m for m in self.modules() if isinstance(m, nn.Linear))
        for layer in linear_layers:
            nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

    def forward(self, z):
        """Return predicted attributes of shape (batch, n_attributes)."""
        skip = F.relu(self.bn1(self.input_layer(z)))

        # Residual block: the second layer has no activation before the add.
        branch = self.res1(skip)
        branch = F.relu(self.bn_res1(branch))
        branch = self.bn_res2(self.res2(branch))
        merged = F.relu(branch + skip)

        narrowed = F.relu(self.bn2(self.fc2(merged)))
        return self.fc3(self.dropout(narrowed))

In [None]:
class ARSIVAE(nn.Module):
    """VAE with an attribute-regression head on the latent mean.

    Args:
        latent_dim: Size of the latent vector shared by all sub-modules.
        n_attributes: Number of attributes predicted from the latent mean.
    """

    def __init__(self, latent_dim=64, n_attributes=14):
        super().__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)
        self.attr_predictor = AttributePredictor(latent_dim, n_attributes)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``std`` clamped to [1e-4, 10]."""
        std = (0.5 * logvar).exp().clamp(min=1e-4, max=10)
        noise = torch.randn_like(std)
        return mu + noise * std

    def forward(self, x):
        """Return ``(recon, mu, logvar, attrs)`` for an image batch ``x``.

        The reconstruction is decoded from a sampled latent; the attributes
        are predicted from the latent mean ``mu``.
        """
        mu, logvar = self.encoder(x)
        sampled = self.reparameterize(mu, logvar)
        reconstruction = self.decoder(sampled)
        predicted_attrs = self.attr_predictor(mu)
        return reconstruction, mu, logvar, predicted_attrs

def loss_function(recon, x, mu, logvar, pred_attrs, true_attrs, beta, lambda_attr):
    """Combine reconstruction, KL and attribute-regression losses.

    Args:
        recon: Reconstructed batch.
        x: Target batch.
        mu: Latent means.
        logvar: Latent log-variances.
        pred_attrs: Predicted attributes.
        true_attrs: Target attributes.
        beta: Weight of the KL term.
        lambda_attr: Weight of the attribute term.

    Returns:
        Tuple ``(total, recon_loss, kl_loss, attr_loss)`` of scalar tensors.
    """
    recon_loss = F.mse_loss(recon, x, reduction='mean')
    attr_loss = F.mse_loss(pred_attrs, true_attrs, reduction='mean')

    # KL against a standard normal, averaged over every element and clamped
    # to [0, 1e4].
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_loss = torch.clamp(-0.5 * torch.mean(kl_terms), 0, 1e4)

    total = recon_loss + beta * kl_loss + lambda_attr * attr_loss
    return total, recon_loss, kl_loss, attr_loss

def get_improved_schedule(epoch, num_epochs=50):
    """Return ``(beta, lambda_attr)`` for a zero-based epoch index.

    Three fixed phases: epochs 0-14 ramp beta from 0 towards 0.001 with
    lambda at 8; epochs 15-34 ramp beta from 0.001 towards 0.251 while lambda
    falls from 8 towards 5; from epoch 35 on the pair is (0.25, 5.0).

    Args:
        epoch: Zero-based epoch index.
        num_epochs: Accepted for interface compatibility; not used, since the
            phase boundaries are fixed at epochs 15 and 35.
    """
    if epoch < 15:
        return 0.001 * (epoch / 15), 8.0

    if epoch < 35:
        progress = (epoch - 15) / 20
        return 0.001 + 0.25 * progress, 8.0 - 3.0 * progress

    return 0.25, 5.0

def plot_reconstructions_epoch(model, loader, device, epoch, save_dir='recon_epochs'):
    """Save a side-by-side figure of originals and reconstructions.

    Takes up to eight samples from the first batch of ``loader``, runs them
    through ``model`` in eval mode and writes
    ``{save_dir}/recon_epoch_{epoch:03d}.png``.

    Args:
        model: Callable returning a 4-tuple whose first item is the
            reconstruction.
        loader: Iterable of batches with a ``'ct'`` tensor.
        device: Device the batch is moved to.
        epoch: Epoch number used in the title and file name.
        save_dir: Output directory, created if missing.

    Note:
        The model is left in eval mode.
    """
    os.makedirs(save_dir, exist_ok=True)
    model.eval()
    batch = next(iter(loader))
    x = batch['ct'][:8].to(device)
    with torch.no_grad():
        recon, _, _, _ = model(x)
    x = x.cpu().numpy()
    recon = recon.cpu().numpy()

    # The first batch may hold fewer than eight samples (small validation set
    # or small batch size); size the grid from what is actually available
    # instead of indexing past the end.
    n_show = x.shape[0]
    # squeeze=False keeps ``axes`` two-dimensional even for a single column.
    fig, axes = plt.subplots(2, n_show, figsize=(2 * n_show, 4), squeeze=False)
    for i in range(n_show):
        axes[0, i].imshow(x[i, 0], cmap='gray')
        axes[0, i].axis('off')
        axes[1, i].imshow(recon[i, 0], cmap='gray')
        axes[1, i].axis('off')
    axes[0, 0].set_title('Original')
    axes[1, 0].set_title('Reconstructed')
    plt.suptitle(f'Epoch {epoch}')
    plt.tight_layout()
    plt.savefig(f'{save_dir}/recon_epoch_{epoch:03d}.png', dpi=150, bbox_inches='tight')
    plt.close(fig)

In [None]:
def train_epoch(model, loader, optimizer, device, beta, lambda_attr):
    """Run one optimisation pass over ``loader``.

    Batches whose input, reconstruction or loss contain NaN/Inf are skipped
    without an optimiser step. Gradients are clipped to a norm of 1.0.

    Returns:
        Tuple ``(total, recon, kl, attr)`` of per-batch mean losses over the
        batches that were actually used, or four NaNs if none were.
    """
    def _non_finite(tensor):
        return torch.isnan(tensor).any() or torch.isinf(tensor).any()

    model.train()
    sums = [0, 0, 0, 0]
    used_batches = 0
    progress = tqdm(loader, desc='Training')
    for batch in progress:
        x = batch['ct'].to(device)
        attrs = batch['attributes'].to(device)
        if _non_finite(x):
            continue

        optimizer.zero_grad()
        recon, mu, logvar, pred_attrs = model(x)
        if _non_finite(recon):
            continue

        losses = loss_function(recon, x, mu, logvar, pred_attrs, attrs, beta, lambda_attr)
        loss = losses[0]
        if _non_finite(loss):
            continue

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        values = [term.item() for term in losses]
        sums = [running + value for running, value in zip(sums, values)]
        used_batches += 1
        progress.set_postfix(
            {key: f'{value:.3f}' for key, value in zip(('loss', 'recon', 'kl', 'attr'), values)}
        )

    if used_batches == 0:
        nan = float('nan')
        return nan, nan, nan, nan
    return tuple(running / used_batches for running in sums)

def validate(model, loader, device, beta, lambda_attr):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: Callable returning ``(recon, mu, logvar, pred_attrs)``.
        loader: Iterable of batches with ``'ct'`` and ``'attributes'``.
        device: Device the batches are moved to.
        beta: Weight of the KL term.
        lambda_attr: Weight of the attribute term.

    Returns:
        Tuple ``(total, recon, kl, attr)`` of per-batch mean losses, or four
        NaNs when the loader yields no batches (mirroring ``train_epoch``)
        instead of raising ``ZeroDivisionError``.
    """
    model.eval()
    total_loss = recon_loss = kl_loss = attr_loss = 0.0
    # Count batches while iterating rather than calling len(loader), so that
    # loaders without a length also work.
    n_batches = 0
    with torch.no_grad():
        for batch in loader:
            x = batch['ct'].to(device)
            attrs = batch['attributes'].to(device)
            recon, mu, logvar, pred_attrs = model(x)
            loss, r, k, a = loss_function(recon, x, mu, logvar, pred_attrs, attrs, beta, lambda_attr)
            total_loss += loss.item()
            recon_loss += r.item()
            kl_loss += k.item()
            attr_loss += a.item()
            n_batches += 1
    if n_batches == 0:
        return float('nan'), float('nan'), float('nan'), float('nan')
    return (total_loss / n_batches, recon_loss / n_batches,
            kl_loss / n_batches, attr_loss / n_batches)

def train_improved(model, train_loader, val_loader, device, epochs=50):
    """Train with the scheduled beta/lambda weights and keep the best model.

    Uses AdamW (lr 1e-4 for encoder and decoder, 5e-4 for the attribute
    head) with cosine annealing over ``epochs``. The state dict with the
    lowest validation attribute loss is written to
    ``best_arsivae_improved.pth``; reconstructions are plotted every fifth
    epoch.

    Returns:
        Tuple ``(model, history)`` where ``history`` maps metric names to
        per-epoch lists.
    """
    param_groups = [
        {'params': list(model.encoder.parameters()), 'lr': 1e-4},
        {'params': list(model.decoder.parameters()), 'lr': 1e-4},
        {'params': list(model.attr_predictor.parameters()), 'lr': 5e-4},
    ]
    optimizer = optim.AdamW(param_groups, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

    history_keys = ('train_total', 'val_total', 'train_recon', 'val_recon',
                    'train_kl', 'val_kl', 'train_attr', 'val_attr', 'beta', 'lambda')
    history = {key: [] for key in history_keys}
    best_val_attr_loss = float('inf')
    best_epoch = 0

    for epoch in range(epochs):
        epoch_number = epoch + 1
        beta, lambda_attr = get_improved_schedule(epoch, epochs)
        history['beta'].append(beta)
        history['lambda'].append(lambda_attr)

        train_stats = train_epoch(model, train_loader, optimizer, device, beta, lambda_attr)
        val_stats = validate(model, val_loader, device, beta, lambda_attr)
        scheduler.step()

        for metric, train_value, val_value in zip(('total', 'recon', 'kl', 'attr'), train_stats, val_stats):
            history[f'train_{metric}'].append(train_value)
            history[f'val_{metric}'].append(val_value)
        train_loss, train_r, train_k, train_a = train_stats
        val_loss, val_r, val_k, val_a = val_stats

        if epoch < 15:
            phase = "Physics"
        elif epoch < 35:
            phase = "Balance"
        else:
            phase = "Finetune"
        print(f"Epoch {epoch_number}/{epochs} [{phase}] beta={beta:.4f} lambda={lambda_attr:.2f}")
        print(f"Train: Total={train_loss:.4f} Recon={train_r:.4f} KL={train_k:.4f} Attr={train_a:.4f}")
        print(f"Val: Total={val_loss:.4f} Recon={val_r:.4f} KL={val_k:.4f} Attr={val_a:.4f}")

        if epoch_number % 5 == 0:
            plot_reconstructions_epoch(model, val_loader, device, epoch_number)
            print(f"Saved reconstruction for epoch {epoch_number}")

        # Checkpoint selection is driven by the attribute loss only.
        if val_a < best_val_attr_loss:
            best_val_attr_loss = val_a
            best_epoch = epoch_number
            torch.save(model.state_dict(), 'best_arsivae_improved.pth')
            print(f"Best model saved val_attr_loss={val_a:.4f}")

    print(f"Best model from epoch {best_epoch} with val_attr_loss={best_val_attr_loss:.4f}")
    return model, history

In [None]:
def extract_features(model, loader, device):
    """Collect latent means, labels and attributes for every batch.

    Puts ``model`` in eval mode and runs its encoder and attribute head
    without gradient tracking.

    Returns:
        Dict of NumPy arrays with keys ``'latents'``, ``'labels'``,
        ``'pred_attrs'`` and ``'true_attrs'``, concatenated over all batches.
    """
    model.eval()
    collected = {'latents': [], 'labels': [], 'pred_attrs': [], 'true_attrs': []}
    with torch.no_grad():
        for batch in loader:
            mu, _ = model.encoder(batch['ct'].to(device))
            predicted = model.attr_predictor(mu)
            collected['latents'].append(mu.cpu().numpy())
            collected['labels'].append(batch['label'].cpu().numpy())
            collected['pred_attrs'].append(predicted.cpu().numpy())
            collected['true_attrs'].append(batch['attributes'].cpu().numpy())

    return {
        'latents': np.vstack(collected['latents']),
        'labels': np.concatenate(collected['labels']),
        'pred_attrs': np.vstack(collected['pred_attrs']),
        'true_attrs': np.vstack(collected['true_attrs']),
    }

def plot_training_curves(history, save_path='training_curves.png'):
    """Plot train/val losses and the beta/lambda schedules on a 2 x 3 grid.

    Args:
        history: Dict of per-epoch lists with the ``train_*`` / ``val_*``
            loss keys plus ``'beta'`` and ``'lambda'``.
        save_path: Destination image path.
    """
    fig, axes = plt.subplots(2, 3, figsize=(15, 8))
    epochs = range(1, len(history['train_total']) + 1)

    # The first four panels each show a train/val loss pair.
    loss_panels = (
        (axes[0, 0], 'total', 'Total Loss'),
        (axes[0, 1], 'recon', 'Reconstruction Loss'),
        (axes[0, 2], 'kl', 'KL Divergence'),
        (axes[1, 0], 'attr', 'Attribute Loss'),
    )
    for ax, metric, title in loss_panels:
        ax.plot(epochs, history[f'train_{metric}'], 'b-', label='Train')
        ax.plot(epochs, history[f'val_{metric}'], 'r-', label='Val')
        ax.set_title(title)
        ax.legend()
        ax.grid(alpha=0.3)

    # The last two panels show the loss-weight schedules.
    schedule_panels = (
        (axes[1, 1], 'beta', 'purple', 'Beta Schedule'),
        (axes[1, 2], 'lambda', 'orange', 'Lambda Schedule'),
    )
    for ax, key, colour, title in schedule_panels:
        ax.plot(epochs, history[key], colour)
        ax.set_title(title)
        ax.grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

def plot_physics_alignment(data, save_path='physics_alignment.png'):
    """Scatter predicted against true values for each of the 14 attributes.

    Args:
        data: Dict with 2-D arrays ``'pred_attrs'`` and ``'true_attrs'``
            whose columns follow the attribute order.
        save_path: Destination image path.

    Returns:
        Tuple ``(r2_scores, mean_r2)`` with one R2 score per attribute.
    """
    predicted = data['pred_attrs']
    actual = data['true_attrs']
    attr_names = ['mean_HU', 'HU_std', 'HU_p10', 'HU_p25', 'HU_p50', 'HU_p75', 'HU_p90',
                  'mask_area', 'mask_frac', 'grad_mean', 'grad_std',
                  'contrast', 'homog', 'entropy']

    fig, axes = plt.subplots(3, 5, figsize=(20, 12))
    axes = axes.flatten()
    r2_scores = []
    for col, name in enumerate(attr_names):
        panel = axes[col]
        true_col = actual[:, col]
        pred_col = predicted[:, col]
        panel.scatter(true_col, pred_col, alpha=0.3, s=10, color='steelblue')

        # Identity line spanning the joint range of both series.
        low = min(true_col.min(), pred_col.min())
        high = max(true_col.max(), pred_col.max())
        panel.plot([low, high], [low, high], 'r--')

        score = r2_score(true_col, pred_col)
        r2_scores.append(score)
        panel.set_xlabel(f'True {name}')
        panel.set_ylabel(f'Pred {name}')
        panel.set_title(f'{name} R2={score:.3f}')
        panel.grid(alpha=0.3)

    # The 15th cell of the grid has no attribute to show.
    axes[14].axis('off')
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    return r2_scores, np.mean(r2_scores)

def plot_latent_space(data, save_path='latent_space.png'):
    """Plot a 2-D PCA of the latents, by class and by predicted attributes.

    The first panel shows class labels (0 as 'Normal', 1 as 'COVID'). The
    remaining five panels colour the same projection by selected columns of
    ``data['pred_attrs']``.

    Args:
        data: Dict with ``'latents'``, ``'labels'`` and ``'pred_attrs'``.
        save_path: Destination image path.
    """
    labels = data['labels']
    pred_attrs = data['pred_attrs']
    projection = PCA(n_components=2).fit_transform(data['latents'])
    pc1, pc2 = projection[:, 0], projection[:, 1]

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    panels = axes.flatten()

    class_ax = panels[0]
    class_styles = (('Normal', '#3498db'), ('COVID', '#e74c3c'))
    for class_id, (label_name, colour) in enumerate(class_styles):
        members = labels == class_id
        class_ax.scatter(projection[members, 0], projection[members, 1],
                         c=colour, label=label_name, alpha=0.6, s=30)
    class_ax.set_title('Class Separation')
    class_ax.legend()
    class_ax.grid(alpha=0.3)

    # (colour-bar label, column in pred_attrs, panel title)
    physics_features = [('mean_HU', 0, 'Mean HU'), ('grad_mean', 9, 'Gradient Mean'),
                        ('entropy', 13, 'Entropy'), ('mask_area', 7, 'Mask Area'),
                        ('contrast', 11, 'Contrast')]
    for panel, (name, attr_idx, title) in zip(panels[1:], physics_features):
        points = panel.scatter(pc1, pc2, c=pred_attrs[:, attr_idx],
                               cmap='viridis', alpha=0.6, s=30)
        panel.set_title(f'{title}')
        plt.colorbar(points, ax=panel, label=name)
        panel.grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

def compute_normalization_stats(dataset):
    """Return per-attribute ``(mean, std)`` over every sample of ``dataset``.

    Each sample's ``'attributes'`` tensor is read through ``dataset[i]``.
    Standard deviations below 1e-6 are replaced by 1.0 so that constant
    attributes are not divided by (near) zero later.
    """
    rows = [
        dataset[index]['attributes'].numpy()
        for index in tqdm(range(len(dataset)), desc="Computing stats")
    ]
    stacked = np.vstack(rows)
    mean = stacked.mean(axis=0)
    raw_std = stacked.std(axis=0)
    std = np.where(raw_std < 1e-6, 1.0, raw_std)
    return mean, std

In [None]:
def main():
    """Run the full pipeline: data, training, evaluation plots and exports.

    Normalisation statistics come from the un-normalised training set and
    are reused for the validation and test sets. After training, the best
    checkpoint is reloaded, plots are written for the validation split, and
    latents, labels and normalisation statistics are saved as ``.npy`` files.

    Returns:
        Tuple ``(model, history, val_data)``.
    """
    SEED = 42
    set_seed(SEED)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    train_dataset_unnorm = CTDataset_ARSIVAE('train.csv', compute_on_fly=True)
    attr_mean, attr_std = compute_normalization_stats(train_dataset_unnorm)

    dataset_kwargs = dict(compute_on_fly=True, attr_mean=attr_mean, attr_std=attr_std)
    train_dataset = CTDataset_ARSIVAE('train.csv', **dataset_kwargs)
    val_dataset = CTDataset_ARSIVAE('val.csv', **dataset_kwargs)
    test_dataset = CTDataset_ARSIVAE('test.csv', **dataset_kwargs)

    loader_kwargs = dict(batch_size=32, num_workers=4, pin_memory=True)
    train_loader = DataLoader(
        train_dataset,
        shuffle=True,
        worker_init_fn=lambda worker_id: np.random.seed(SEED + worker_id),
        **loader_kwargs,
    )
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    model = ARSIVAE(latent_dim=64, n_attributes=14).to(device)
    NUM_EPOCHS = 50
    model, history = train_improved(model, train_loader, val_loader, device, epochs=NUM_EPOCHS)
    # Evaluate the best checkpoint rather than the final-epoch weights.
    model.load_state_dict(torch.load('best_arsivae_improved.pth'))
    plot_training_curves(history, 'training_curves.png')

    split_loaders = (('train', train_loader), ('val', val_loader), ('test', test_loader))
    features = {}
    for split, loader in split_loaders:
        features[split] = extract_features(model, loader, device)
    val_data = features['val']

    r2_scores, avg_r2 = plot_physics_alignment(val_data, 'val_physics_alignment.png')
    plot_latent_space(val_data, 'val_latent_space.png')

    for kind in ('latents', 'labels'):
        for split, _ in split_loaders:
            np.save(f'{split}_{kind}.npy', features[split][kind])
    np.save('attr_mean.npy', attr_mean)
    np.save('attr_std.npy', attr_std)

    print(f"Training complete. Avg R2={avg_r2:.4f}")
    return model, history, val_data

# Script entry point: run the pipeline only when this file is executed
# directly, keeping the trained model, history and validation features in
# module-level names.
if __name__=='__main__':
    model,history,val_data=main()