# In [None]:
import os
import sys
import numpy as np
import cv2
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torchvision import datasets, transforms
import timm
import matplotlib.pyplot as plt
from tqdm import tqdm
import warnings
from sklearn.model_selection import ParameterGrid
import json
warnings.filterwarnings('ignore')
from replknet import RepLKNet

# Pick the compute device once at import time; every function below moves
# tensors to this module-level `device`.
_use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if _use_cuda else torch.device("cpu")
print(f"Using device: {device}")

# The grid search wraps each model in nn.DataParallel and this script expects
# a two-GPU machine; any other count only produces a warning, not an error.
_gpu_count = torch.cuda.device_count()
if _gpu_count != 2:
    print(f"Warning: Expected 2 GPUs, but found {_gpu_count}.")
else:
    print("Using 2 GPUs!")

def convert_sync_batchnorm(module):
    """Recursively swap every ``nn.SyncBatchNorm`` in *module* for ``nn.BatchNorm2d``.

    SyncBatchNorm is meant to run inside a distributed process group; plain
    BatchNorm2d lets the network run in a single process (e.g. under
    nn.DataParallel). ``load_model`` applies this right after building the
    RepLKNet, presumably because that constructor emits SyncBatchNorm layers.

    Args:
        module: any nn.Module (or a single SyncBatchNorm layer).

    Returns:
        The converted module. If *module* itself is not a SyncBatchNorm, the
        same object is returned with its children replaced in place;
        otherwise a new BatchNorm2d is returned.
    """
    converted = module
    if isinstance(module, torch.nn.SyncBatchNorm):
        converted = torch.nn.BatchNorm2d(
            module.num_features,
            module.eps,
            module.momentum,
            module.affine,
            module.track_running_stats,
        )
        if module.affine:
            # Learnable affine parameters are copied, not aliased.
            converted.weight.data = module.weight.data.clone().detach()
            converted.bias.data = module.bias.data.clone().detach()
        # Running statistics are handed over by reference.
        for buffer_name in ("running_mean", "running_var", "num_batches_tracked"):
            setattr(converted, buffer_name, getattr(module, buffer_name))

    for child_name, child in module.named_children():
        converted.add_module(child_name, convert_sync_batchnorm(child))
    return converted

def _extract_state_dict(checkpoint):
    """Return the flat ``name -> tensor`` mapping in *checkpoint*, without a DataParallel prefix.

    *checkpoint* may be a bare state dict, or a dict that nests it under
    'model_state_dict' (the periodic checkpoints written by ``train``),
    'model' or 'state_dict' (other common wrapper keys). A 'module.' key
    prefix -- present when the state dict was taken from an nn.DataParallel
    wrapper, as ``train`` does -- is stripped so keys match a bare model.
    """
    state_dict = checkpoint
    for wrapper_key in ("model_state_dict", "model", "state_dict"):
        if isinstance(state_dict, dict) and isinstance(state_dict.get(wrapper_key), dict):
            state_dict = state_dict[wrapper_key]
            break
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def load_model(model_config, checkpoint_path):
    """Build a RepLKNet-31B-shaped classifier and initialise its backbone from a checkpoint.

    Args:
        model_config: dict; this function reads 'channels', 'drop_path_rate'
            and 'class_n'.
        checkpoint_path: file readable by ``torch.load`` holding either a bare
            state dict or one wrapped as described in ``_extract_state_dict``.

    Returns:
        The model with SyncBatchNorm converted to BatchNorm2d, every
        checkpoint tensor whose (unprefixed) name exists in the model and does
        not contain 'head' loaded, and a freshly initialised
        ``nn.Linear(in_features, class_n)`` head. The model is not moved to
        any device; the caller does that.
    """
    model = RepLKNet(
        large_kernel_sizes=[31,29,27,13],
        layers=[2,2,18,2],
        channels=model_config["channels"],
        drop_path_rate=model_config["drop_path_rate"],
        small_kernel=5,
        num_classes=model_config["class_n"]
    )

    model = convert_sync_batchnorm(model)

    # The model still lives on CPU here, so load the tensors there too instead
    # of staging the whole checkpoint on the GPU just to copy it back.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    state_dict = _extract_state_dict(checkpoint)

    model_dict = model.state_dict()
    # The classifier head is excluded: it is replaced below for class_n classes.
    pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict and 'head' not in k}
    if pretrained_dict:
        print(f"Loaded {len(pretrained_dict)}/{len(model_dict)} tensors from {checkpoint_path}")
    else:
        # Nothing matched: fine-tuning would silently start from random init.
        print(f"Warning: no weights in {checkpoint_path} matched the model; "
              f"the backbone keeps its random initialisation.")
    model_dict.update(pretrained_dict)

    model.load_state_dict(model_dict, strict=False)

    in_features = model.head.in_features
    model.head = nn.Linear(in_features, model_config["class_n"])

    return model

def get_data_loaders(train_dir, val_dir, input_size, batch_size):
    """Create ImageFolder-backed train/val DataLoaders with identical preprocessing.

    Every image is resized to ``input_size x input_size`` (aspect ratio not
    preserved), converted to a tensor and normalised with the standard
    ImageNet mean/std. No augmentation is applied to the training split.
    The training loader shuffles; the validation loader keeps file order.

    Returns:
        ``(train_loader, val_loader)``.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    preprocess = transforms.Compose([
        transforms.Resize((input_size, input_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ])

    loader_options = dict(batch_size=batch_size, num_workers=8, pin_memory=True)

    def make_loader(root, shuffle):
        dataset = datasets.ImageFolder(root, transform=preprocess)
        return torch.utils.data.DataLoader(dataset, shuffle=shuffle, **loader_options)

    return make_loader(train_dir, True), make_loader(val_dir, False)

def train(model, train_loader, val_loader, criterion, optimizer, num_epochs, save_path,
          model_config=None, checkpoint_root='/workspace/REPLKNET/fineee'):
    """Fine-tune *model* for *num_epochs* epochs, keeping the best weights.

    After each training pass the model is evaluated with ``validate``.
    Whenever validation accuracy improves, the model's state dict is written
    to *save_path*. Every 5th epoch a resumable checkpoint (model and
    optimizer state plus metrics) is written to
    ``<checkpoint_root>/<save_path stem>/checkpoint_epoch_<N>.pth``. When
    training finishes, the per-epoch metrics go to ``training_log.json`` in
    that same folder.

    Args:
        model: network to train; expected to already be on the module-level
            ``device`` (it may be wrapped in nn.DataParallel).
        train_loader, val_loader: iterables of ``(inputs, labels)`` batches.
        criterion: loss function; assumed to average over the batch
            (``reduction='mean'``, the nn.CrossEntropyLoss default).
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over ``train_loader``.
        save_path: file receiving the best state dict; its parent directory is
            created if missing.
        model_config: optional dict stored in each periodic checkpoint so the
            run's architecture/hyper-parameters can be recovered later.
        checkpoint_root: parent directory of the per-run checkpoint folder.

    Returns:
        The best validation accuracy (percent) observed, or 0.0 if it never
        exceeded 0.
    """
    best_val_acc = 0.0
    training_log = []

    # Name the per-run folder after the weights file, not after its parent
    # directory: runs that share a save directory (as grid_search's do) would
    # otherwise share one folder and overwrite each other's checkpoints and
    # training_log.json.
    run_name = os.path.splitext(os.path.basename(save_path))[0]
    checkpoint_dir = os.path.join(checkpoint_root, run_name)
    os.makedirs(checkpoint_dir, exist_ok=True)
    save_dir = os.path.dirname(save_path)
    if save_dir:
        # torch.save does not create missing parent directories.
        os.makedirs(save_dir, exist_ok=True)

    for epoch in range(1, num_epochs + 1):
        model.train()
        loss_sum = 0.0
        train_correct = 0
        train_total = 0

        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs} [Train]")

        for inputs, labels in train_pbar:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_n = labels.size(0)
            # `loss` is a batch mean: weight it by the batch size so dividing
            # by the sample count yields a true per-sample average (summing
            # raw batch means would under-report by ~the batch size).
            loss_sum += loss.item() * batch_n
            _, predicted = outputs.max(1)
            train_correct += predicted.eq(labels).sum().item()
            train_total += batch_n

            train_pbar.set_postfix({
                'train_loss': f"{loss_sum/train_total:.4f}",
                'train_acc': f"{100.*train_correct/train_total:.2f}%"
            })

        val_loss, val_acc = validate(model, val_loader, criterion)

        # max(..., 1) guards against an empty training loader.
        seen = max(train_total, 1)
        train_loss = loss_sum / seen
        train_acc = 100. * train_correct / seen

        training_log.append({
            'epoch': epoch,
            'train_loss': train_loss,
            'train_acc': train_acc,
            'val_loss': val_loss,
            'val_acc': val_acc
        })

        print(f"Epoch {epoch}/{num_epochs}")
        print(f"Train - Loss: {train_loss:.4f}, Accuracy: {train_acc:.2f}%")
        print(f"Val   - Loss: {val_loss:.4f}, Accuracy: {val_acc:.2f}%")

        if epoch % 5 == 0:
            checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pth')
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'val_acc': val_acc
            }
            if model_config is not None:
                checkpoint['model_config'] = model_config
            torch.save(checkpoint, checkpoint_path)
            print(f"Checkpoint saved to {checkpoint_path}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), save_path)
            print(f"New best model saved to {save_path}")

    # Save training log to JSON file
    log_path = os.path.join(checkpoint_dir, 'training_log.json')
    with open(log_path, 'w') as f:
        json.dump(training_log, f, indent=2)
    print(f"Training log saved to {log_path}")

    return best_val_acc

def validate(model, val_loader, criterion):
    """Evaluate *model* on *val_loader* without gradient tracking.

    Puts the model in eval mode (and leaves it there) and moves each batch
    to the module-level ``device``.

    Args:
        model: network to evaluate.
        val_loader: iterable of ``(inputs, labels)`` batches.
        criterion: loss function; assumed to average over the batch
            (``reduction='mean'``, the nn.CrossEntropyLoss default).

    Returns:
        ``(mean per-sample loss, accuracy in percent)``; ``(0.0, 0.0)`` if
        the loader yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        pbar = tqdm(val_loader, desc="Validating")
        for inputs, labels in pbar:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            batch_n = labels.size(0)
            # Weight the batch-mean loss by batch size so the final division
            # by `total` is a per-sample mean (the last batch may be smaller).
            loss_sum += loss.item() * batch_n
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            total += batch_n

            pbar.set_postfix({
                'val_loss': f"{loss_sum/total:.4f}",
                'val_acc': f"{100.*correct/total:.2f}%"
            })

    if total == 0:
        return 0.0, 0.0
    return loss_sum / total, 100. * correct / total

            
def grid_search(model_paths, param_grid, train_dir, val_dir, save_dir="/save/path"):
    """Fine-tune one model per (checkpoint, hyper-parameter combination) and collect results.

    Args:
        model_paths: pretrained checkpoint paths passed to ``load_model``; a
            path containing "384" selects 384px inputs, anything else 224px.
        param_grid: sklearn ``ParameterGrid`` spec with keys 'drop_rate',
            'learning_rate', 'weight_decay', 'epochs' and 'batch_size'.
        train_dir, val_dir: ImageFolder roots for the two splits.
        save_dir: directory receiving each run's best-weights file.

    Returns:
        One dict per run with keys 'model' (weights file name), 'params',
        'best_val_acc', 'save_path' and 'model_config'.
    """
    results = []

    for model_path in model_paths:
        # Presumably the checkpoint name encodes its pretraining resolution.
        input_size = 384 if "384" in model_path else 224

        for params in ParameterGrid(param_grid):
            print(f"Training with parameters: {params}")

            model_config = {
                "class_n": 2,
                "unit_n": 1024,
                "input_size": input_size,
                "size": 12,
                "lr": params['learning_rate'],
                "weight_decay": params['weight_decay'],
                "channels": [128, 256, 512, 1024],
                "model_name": "RepLKNet-31B",
                "drop_path_rate": params['drop_rate'],
                "batch_size": params['batch_size'],
                "epochs": params['epochs']
            }

            train_loader, val_loader = get_data_loaders(train_dir, val_dir, input_size, params['batch_size'])

            model = load_model(model_config, model_path)
            model = nn.DataParallel(model).to(device)

            criterion = nn.CrossEntropyLoss()
            optimizer = optim.AdamW(model.parameters(), lr=params['learning_rate'], weight_decay=params['weight_decay'])

            model_name = f"RepLKNet_{input_size}_dr{params['drop_rate']}_lr{params['learning_rate']}_wd{params['weight_decay']}_e{params['epochs']}_b{params['batch_size']}.pth"
            save_path = os.path.join(save_dir, model_name)

            # train() gets exactly its seven required arguments; the run's
            # config is recorded in `results` so it stays reproducible.
            best_val_acc = train(model, train_loader, val_loader, criterion, optimizer,
                                 params['epochs'], save_path)

            results.append({
                'model': model_name,
                'params': params,
                'best_val_acc': best_val_acc,
                'save_path': save_path,
                'model_config': model_config,
            })

            # Drop this run's model, optimizer state and loaders before the
            # next iteration builds a new model; otherwise they stay
            # referenced (and resident on the GPU) until their names are
            # rebound, doubling peak memory during load_model.
            del model, optimizer, criterion, train_loader, val_loader
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    return results
if __name__ == "__main__":
    # Request 'spawn' for worker processes (presumably because CUDA does not
    # support forked subprocesses); a RuntimeError means a start method was
    # already fixed for this interpreter, which is tolerated.
    try:
        torch.multiprocessing.set_start_method('spawn')
    except RuntimeError:
        print("The start method has already been set. Using the existing start method.")

    model_paths = ['REPLKNET_model_path']

    param_grid = dict(
        drop_rate=[0.3, 0.5, 0.7],
        learning_rate=[1e-5, 1e-4],
        weight_decay=[1e-4, 1e-3],
        epochs=[10],
        batch_size=[32, 64],
    )

    train_dir, val_dir = "path/train_dataset", "path/val_dataset"

    results = grid_search(model_paths, param_grid, train_dir, val_dir)

    # Persist every run's outcome, then report the best by validation accuracy.
    with open('grid_search_results.json', 'w') as results_file:
        json.dump(results, results_file, indent=2)

    best_result = max(results, key=lambda entry: entry['best_val_acc'])
    print("Best Result:")
    print(json.dumps(best_result, indent=2))