In [None]:
import sys
sys.path.append('..')

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm
import time

from src.dataset_original import STEM4DDataset
from src.model_autoencoder import CNNAutoencoder, ResNetDenoiser
from src.losses_original import CombinedLoss

# Seed NumPy's and PyTorch's global RNGs first so every run draws identical random numbers.
np.random.seed(42)
torch.manual_seed(42)

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
# Cell 2: Data loading
data_path = Path("../data")
results_path = Path("../results")
results_path.mkdir(exist_ok=True)
(results_path / 'checkpoints').mkdir(exist_ok=True)
(results_path / 'figures').mkdir(exist_ok=True)

print("Loading data...")
low_dose = np.load(data_path / "03_denoising_SrTiO3_High_mag_Low_dose.npy")
high_dose = np.load(data_path / "03_denoising_SrTiO3_High_mag_High_dose.npy")

# Later cells unpack the data as (scan_x, scan_y, det_x, det_y) and multiply a
# detector mask built from `high_dose` into the `low_dose` patterns. Fail fast here,
# with a readable message, if either array is not 4D or the detector grids differ.
assert low_dose.ndim == 4, f"expected 4D low-dose data, got shape {low_dose.shape}"
assert high_dose.ndim == 4, f"expected 4D high-dose data, got shape {high_dose.shape}"
assert low_dose.shape[2:] == high_dose.shape[2:], (
    f"detector grids differ: low dose {low_dose.shape[2:]} vs high dose {high_dose.shape[2:]}"
)

print(f"Low dose shape: {low_dose.shape}")
print(f"High dose shape: {high_dose.shape}")
print(f"Low dose range: [{low_dose.min():.2f}, {low_dose.max():.2f}]")

In [None]:
# Cell 3: Creating a mask
# Position-averaged diffraction pattern of the high-dose scan: mean over both scan axes.
pacbed_high = high_dose.mean(axis=(0, 1))
# Detector pixels brighter than 10% of that pattern's peak form the bright-field mask.
threshold = pacbed_high.max() * 0.1
bf_mask = pacbed_high > threshold

print(f"Bright field pixels: {bf_mask.sum()} / {bf_mask.size}")

In [None]:
# Cell 4: Creating a Dataset and DataLoader
batch_size = 8
# Neighbourhood size passed to the dataset; presumably the side length of a square
# window of scan positions (3 -> 8 neighbours, matching the 8-channel model) — confirm in STEM4DDataset.
window_size = 3

dataset = STEM4DDataset(
    noisy_data=low_dose,
    window_size=window_size,
    bright_field_mask=bf_mask
)

print(f"Dataset size: {len(dataset)}")

# Hold out 10% of the samples for validation, but never fewer than one sample.
# int(n * 0.1) is 0 when n < 10; the validation loader would then be empty and the
# per-epoch validation loss could not be computed.
val_fraction = 0.1
n_samples = len(dataset)
if n_samples < 2:
    raise ValueError(f"Need at least 2 samples for a train/validation split, got {n_samples}")
n_val = max(1, int(n_samples * val_fraction))
n_train = n_samples - n_val

train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [n_train, n_val],
    generator=torch.Generator().manual_seed(42)  # fixed split, independent of the global RNG state
)

train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=0
)

val_loader = DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False, num_workers=0
)

In [None]:
# Cell 5: Model selection and initialisation
model_type = "cnn"

if model_type == "cnn":
    # 8 input channels = the 3x3 window minus its centre, as stacked by the inference
    # helper; presumably this must also match what STEM4DDataset yields for window_size=3 — confirm.
    model = CNNAutoencoder(in_channels=8, base_features=16).to(device)
    model_name = "CNN_Autoencoder"
else:
    # Without this branch an unsupported value leaves `model` undefined and the
    # cell fails below with a confusing NameError.
    raise ValueError(f"Unsupported model_type {model_type!r}; supported values: 'cnn'")

print(f"Model: {model_name}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")


criterion = CombinedLoss(warmup_epochs=8, pacbed_weight=0.02, stem_weight=0.01)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# Halve the learning rate after 5 epochs without improvement in the validation loss.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', patience=5, factor=0.5
)

In [None]:
# Cell 6: Training and validation functions
def train_epoch(model, loader, criterion, optimizer, device, epoch):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        model: torch module to train; it is switched to training mode here.
        loader: sized iterable (e.g. a DataLoader) of ``(inputs, targets, _)`` batches;
            the third element of each batch is ignored.
        criterion: loss callable ``criterion(outputs, targets)`` that also exposes
            ``set_epoch(epoch)``, called once before the pass.
        optimizer: torch optimizer over ``model``'s parameters.
        device: device that inputs and targets are moved to.
        epoch: epoch index forwarded to ``criterion.set_epoch`` and shown in the progress bar.

    Returns:
        Unweighted mean of the per-batch loss values. Every batch counts equally,
        including a smaller final batch.

    Raises:
        ValueError: if ``loader`` contains no batches (the mean would be undefined).
    """
    n_batches = len(loader)
    if n_batches == 0:
        raise ValueError("train_epoch received an empty loader; check the train/validation split")

    model.train()
    criterion.set_epoch(epoch)
    total_loss = 0.0

    pbar = tqdm(loader, desc=f'Training epoch {epoch}')
    for inputs, targets, _ in pbar:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()

        # One host sync per batch; the value is reused for the running total and the progress bar.
        batch_loss = loss.item()
        total_loss += batch_loss
        pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

    return total_loss / n_batches

def validate_epoch(model, loader, criterion, device, epoch):
    """Evaluate ``model`` on ``loader`` without gradient tracking and return the mean batch loss.

    Args:
        model: torch module to evaluate; it is switched to eval mode here.
        loader: sized iterable (e.g. a DataLoader) of ``(inputs, targets, _)`` batches;
            the third element of each batch is ignored.
        criterion: loss callable ``criterion(outputs, targets)`` that also exposes
            ``set_epoch(epoch)``.
        device: device that inputs and targets are moved to.
        epoch: epoch index forwarded to ``criterion.set_epoch``.

    Returns:
        Unweighted mean of the per-batch loss values.

    Raises:
        ValueError: if ``loader`` contains no batches (the mean would be undefined).
    """
    n_batches = len(loader)
    if n_batches == 0:
        raise ValueError("validate_epoch received an empty loader; the validation split has no samples")

    model.eval()
    # Same epoch as training, so any epoch-dependent criterion behaviour (e.g. warm-up) matches.
    criterion.set_epoch(epoch)
    total_loss = 0.0

    with torch.no_grad():
        for inputs, targets, _ in tqdm(loader, desc='Validation'):
            outputs = model(inputs.to(device))
            total_loss += criterion(outputs, targets.to(device)).item()

    return total_loss / n_batches

In [None]:
# Cell 7: Training
num_epochs = 124
best_val_loss = float('inf')
best_epoch = None
best_state = None  # in-memory copy of the best weights, restored after the loop
history = {'train_loss': [], 'val_loss': []}
checkpoint_path = results_path / 'checkpoints' / f'best_model_{model_type}.pth'

print(f"Starting UNSUPERVISED training with {model_name}...")
start_time = time.time()

for epoch in range(num_epochs):
    train_loss = train_epoch(model, train_loader, criterion, optimizer, device, epoch)
    val_loss = validate_epoch(model, val_loader, criterion, device, epoch)

    # ReduceLROnPlateau lowers the learning rate when the validation loss stops improving.
    scheduler.step(val_loss)

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)

    print(f"Epoch {epoch+1}/{num_epochs} - "
          f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_epoch = epoch
        best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss
        }, checkpoint_path)
        print("✓ Saved best model!")

print(f"Training completed in {(time.time() - start_time)/60:.2f} minutes")

# After the loop `model` holds the LAST epoch's weights. Restore the best (checkpointed)
# weights so that later cells, such as denoising, use the model with the lowest validation loss.
if best_state is not None:
    model.load_state_dict(best_state)
    print(f"Restored best weights from epoch {best_epoch+1} (val loss {best_val_loss:.4f})")

# 1-based x-axis so the plot matches the epoch numbers printed above.
epochs_axis = range(1, len(history['train_loss']) + 1)
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(epochs_axis, history['train_loss'], label='Train Loss')
ax.plot(epochs_axis, history['val_loss'], label='Val Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title(f'Training History ({model_name})')
ax.legend()
ax.grid(True)
fig.savefig(results_path / 'figures' / f'training_history_{model_type}.png')
plt.show()

In [None]:
def denoise_with_model(model, noisy_data, bf_mask=None, batch_size=64, device='cpu', window_size=3):
    """Denoise every scan position of a 4D dataset with a trained model.

    For each scan position at least ``window_size // 2`` positions away from the scan edge,
    the patterns of its ``window_size**2 - 1`` neighbours are stacked as input channels.
    The centre pattern itself is excluded. Channel 0 of the model output becomes the
    denoised pattern for that position. Positions closer to the edge have an incomplete
    neighbourhood and are copied unchanged from ``noisy_data``.

    Args:
        model: torch module taking a float32 tensor of shape
            (batch, window_size**2 - 1, det_x, det_y). ``output[:, 0]`` must be a
            (batch, det_x, det_y) pattern stack; presumably the output is
            (batch, C, det_x, det_y) — confirm against the model definition.
            The model must already live on ``device``.
        noisy_data: array of shape (scan_x, scan_y, det_x, det_y).
        bf_mask: optional (det_x, det_y) array multiplied into every input pattern;
            presumably the same mask the training dataset used — confirm.
        batch_size: number of scan positions per forward pass.
        device: torch device the input batches are moved to.
        window_size: odd side length of the square neighbourhood. The default 3
            gives the 8 input channels used by the CNN in this notebook.

    Returns:
        Array with the same shape as ``noisy_data``. Floating inputs keep their dtype.
        Integer inputs are promoted to a float dtype so model outputs are not truncated.

    Raises:
        ValueError: if ``window_size`` is not an odd number >= 3.
    """
    if window_size < 3 or window_size % 2 != 1:
        raise ValueError(f"window_size must be an odd integer >= 3, got {window_size}")

    scan_x, scan_y, det_x, det_y = noisy_data.shape
    offset = window_size // 2
    n_window = window_size * window_size
    center = n_window // 2  # flat (row-major) index of the position itself within its window

    # Start from a copy so edge positions keep their noisy patterns. Integer (e.g. count)
    # data is promoted to float: an integer output buffer would truncate the model's float output.
    out_dtype = np.result_type(noisy_data.dtype, np.float32)
    denoised = noisy_data.astype(out_dtype, copy=True)

    valid_positions = [
        (x, y)
        for x in range(offset, scan_x - offset)
        for y in range(offset, scan_y - offset)
    ]

    model.eval()
    with torch.no_grad():
        for batch_start in tqdm(range(0, len(valid_positions), batch_size),
                                desc=f"Denoising with {model.__class__.__name__}"):
            batch_positions = valid_positions[batch_start:batch_start + batch_size]

            batch_inputs = []
            for x, y in batch_positions:
                window = noisy_data[x - offset:x + offset + 1, y - offset:y + offset + 1]
                # Flatten the window row-major (same neighbour order as a nested i/j loop)
                # and drop the centre pattern.
                neighbors = np.delete(
                    window.reshape(n_window, det_x, det_y), center, axis=0
                ).astype(np.float32)
                if bf_mask is not None:
                    neighbors = neighbors * bf_mask
                batch_inputs.append(neighbors)

            # The explicit float32 guards against a float64 mask promoting the batch.
            inputs = torch.as_tensor(np.stack(batch_inputs), dtype=torch.float32, device=device)
            outputs = model(inputs)[:, 0].cpu().numpy()

            for (x, y), pattern in zip(batch_positions, outputs):
                denoised[x, y] = pattern

    return denoised

In [None]:
# Denoise the full low-dose scan with the trained model and write the result to disk.
print("\nDenoising with CNN...")
output_file = results_path / 'denoised_CNN.npy'
denoised_results = denoised = denoise_with_model(
    model, low_dose, bf_mask, device=device
)
np.save(output_file, denoised)
