In [12]:
import argparse
import torch
import numpy as np
import random
import os 
from model import BaselineModel
from dataloader import *
from train_utils import train_baseline, train_coral, train_adversarial, train_adabn
from plot_utils import *
from analysis_utils import *

%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Training

In [9]:
class Trainer:
    """Shared setup for running each domain-adaptation training method.

    The constructor keeps ``args`` and ``device`` and obtains the source and
    target loaders from ``prepare_data(batch_size=args.batch_size)``. Each
    ``train_*`` method builds a brand-new ``BaselineModel`` on ``device`` and
    hands it, both loaders, ``args`` and ``device`` to the training function
    of the same name, returning that function's result unchanged.
    """

    def __init__(self, args, device):
        """Store the run configuration and create the two data loaders.

        Args:
            args: Configuration object; only ``args.batch_size`` is read here,
                the whole object is forwarded to the training functions.
            device: Device every model is moved to before training.
        """
        self.args, self.device = args, device
        loaders = prepare_data(batch_size=args.batch_size)
        self.source_loader, self.target_loader = loaders

    def _fit(self, train_fn):
        """Call ``train_fn`` with a freshly constructed model and return its result."""
        # A new model per call, so one method's training never leaks into another's.
        fresh_model = BaselineModel().to(self.device)
        return train_fn(
            fresh_model,
            self.source_loader,
            self.target_loader,
            self.args,
            self.device,
        )

    def train_baseline(self):
        """Run ``train_baseline`` on a fresh model and return its result."""
        return self._fit(train_baseline)

    def train_coral(self):
        """Run ``train_coral`` on a fresh model and return its result."""
        return self._fit(train_coral)

    def train_adversarial(self):
        """Run ``train_adversarial`` on a fresh model and return its result."""
        return self._fit(train_adversarial)

    def train_adabn(self):
        """Run ``train_adabn`` on a fresh model and return its result."""
        return self._fit(train_adabn)



def main(args):
    """Train the selected adaptation method(s), report and persist the results.

    The device is CUDA when ``torch.cuda.is_available()`` and CPU otherwise.
    ``set_seed(args.seed)`` is called before the ``Trainer`` is built. Methods
    run in the fixed order baseline, coral, adversarial, adabn; a method runs
    when ``args.method`` equals its name or ``'all'``.

    Args:
        args: Configuration object providing ``seed`` and ``method`` here; it
            is also passed as-is to ``Trainer``.

    Returns:
        dict: Maps each executed method name to whatever the corresponding
        ``Trainer.train_*`` call returned. Every value must support
        ``value['final_target_acc']``, which is printed with four decimals.

    Side effects:
        Creates ``results/`` if missing and pickles the returned dict to
        ``results/results_<YYYYmmdd_HHMMSS>.pkl``.
    """
    import pickle
    from datetime import datetime

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    set_seed(args.seed)
    trainer = Trainer(args, device)

    # Insertion order here is also the order of the summary printed below.
    results = {}
    for method in ('baseline', 'coral', 'adversarial', 'adabn'):
        if args.method in (method, 'all'):
            results[method] = getattr(trainer, f'train_{method}')()

    print("\nFinal Target Accuracies:")
    for model_name, model_results in results.items():
        print(f"{model_name}: {model_results['final_target_acc']:.4f}")

    # Persist under a second-resolution timestamp so separate runs get separate files.
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    os.makedirs('results', exist_ok=True)
    with open(f'results/results_{timestamp}.pkl', 'wb') as f:
        pickle.dump(results, f)

    return results

In [None]:

# Notebook entry point: the configuration is declared through argparse, but an
# explicit empty argument list is parsed so the defaults below are what runs.
parser = argparse.ArgumentParser(description='Domain Adaptation Methods')
parser.add_argument(
    '--method',
    type=str,
    default='all',
    choices=['baseline', 'coral', 'adversarial', 'adabn', 'all'],
)
# The remaining options differ only in flag, type and default value.
for option, caster, fallback in (
    ('--batch_size', int, 32),
    ('--lr', float, 0.0001),
    ('--epochs', int, 100),
    ('--seed', int, 42),
    ('--coral_weight', float, 1.0),
    ('--adversarial_weight', float, 1.0),
):
    parser.add_argument(option, type=caster, default=fallback)
args = parser.parse_args([])

# Train with the defaults above and keep the per-method results for later cells.
results = main(args)

In [None]:
# Plot the curves for the `results` dict returned by main(args) in the cell above.
# save=True / show=True are forwarded to plot_training_curves; presumably they
# write the figures to disk and display them inline — confirm in plot_utils.
plot_training_curves(results, save=True, show=True)

# Analysis

### Per-class performance

In [11]:
# Create fresh model instances and load the saved states
def load_model_with_states(path, device="cpu"):
    """Build a fresh ``BaselineModel`` and load saved weights from ``path``.

    Two checkpoint layouts are accepted:

    * a dict containing a ``'model'`` key: only ``checkpoint['model']`` is
      loaded and every other entry is ignored (the original author's note says
      the adversarial checkpoint also stores a discriminator this way);
    * anything else: the loaded object itself is used as the state dict.

    Args:
        path: Checkpoint file to read with ``torch.load``.
        device: Device the model is moved to and the tensors are mapped onto.

    Returns:
        The ``BaselineModel`` instance with the checkpoint weights loaded.

    Raises:
        Whatever ``torch.load`` or ``load_state_dict`` raise. Because the file
        is read with ``weights_only=True``, a checkpoint that contains pickled
        objects other than tensors and plain containers is rejected instead of
        being executed.
    """
    model = BaselineModel().to(device)

    # weights_only=True limits unpickling to tensors and basic containers, so
    # a tampered .pth file cannot run arbitrary code; it also resolves the
    # torch.load warning this call produced.
    checkpoint = torch.load(path, map_location=device, weights_only=True)

    # Pick the state dict out of either supported layout.
    if isinstance(checkpoint, dict) and 'model' in checkpoint:
        weights = checkpoint['model']
    else:
        weights = checkpoint
    model.load_state_dict(weights)

    # NOTE(review): the model is not switched to eval() here; confirm that the
    # analysis helpers do so before running inference.
    return model

# Load one trained checkpoint per adaptation method; this whole cell runs on CPU.
models_dict = {
    method: load_model_with_states(f"models/{method}_final.pth", device="cpu")
    for method in ('baseline', 'coral', 'adversarial', 'adabn')
}

# Placeholder labels for the ten classes; swap in the real fault names if available.
class_names = [f'Fault {idx}' for idx in range(10)]

# NOTE: `args` comes from the training cell, which must have been run in this
# kernel first; only the data loaders of this Trainer are used below.
trainer = Trainer(args, device="cpu")

# Evaluate every loaded model on both domains, per class.
performance_results = analyze_per_class_performance(
    models_dict,
    trainer.source_loader,
    trainer.target_loader,
    device="cpu",
    class_names=class_names,
    save_plots=True,
)

# Summarise the source-to-target differences from the results above.
print_domain_shift_analysis(performance_results, class_names)

  state_dict = torch.load(path, map_location=device)



Analyzing baseline...


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



baseline Summary:

Source Domain Performance:
Overall Accuracy: 1.0000

Target Domain Performance:
Overall Accuracy: 0.8000

Per-class F1 scores (Source → Target):
Fault 0: 1.0000 → 1.0000
Fault 1: 1.0000 → 0.0000
Fault 2: 1.0000 → 1.0000
Fault 3: 1.0000 → 1.0000
Fault 4: 1.0000 → 0.6667
Fault 5: 1.0000 → 1.0000
Fault 6: 1.0000 → 0.0000
Fault 7: 1.0000 → 1.0000
Fault 8: 1.0000 → 1.0000
Fault 9: 1.0000 → 0.6667

Analyzing coral...

coral Summary:

Source Domain Performance:
Overall Accuracy: 1.0000

Target Domain Performance:
Overall Accuracy: 0.8950

Per-class F1 scores (Source → Target):
Fault 0: 1.0000 → 1.0000
Fault 1: 1.0000 → 1.0000
Fault 2: 1.0000 → 1.0000
Fault 3: 1.0000 → 0.9474
Fault 4: 1.0000 → 0.7843
Fault 5: 1.0000 → 1.0000
Fault 6: 1.0000 → 0.4615
Fault 7: 1.0000 → 0.7917
Fault 8: 1.0000 → 1.0000
Fault 9: 1.0000 → 0.8649

Analyzing adversarial...

adversarial Summary:

Source Domain Performance:
Overall Accuracy: 1.0000

Target Domain Performance:
Overall Accuracy: 1.0000

In [13]:
# Run the feature space analysis on the models loaded in the per-class
# analysis cell, reusing that cell's `trainer` for the source/target loaders.
# Both `models_dict` and `trainer` must already exist in the kernel.
feature_analysis_results = analyze_feature_space(
    models_dict,
    trainer.source_loader,
    trainer.target_loader,
    device="cpu"
)

# Per the original author's note, the analysis writes several visualizations
# to the plots/feature_space/ directory (not verifiable from this notebook;
# confirm in analysis_utils):
# 1. t-SNE visualizations for each model
# 2. PCA visualizations for each model
# 3. Comparative bar plots of domain distances
# The printed per-model statistics are the mean, covariance and MMD distances.


Analyzing feature space for baseline...

baseline Domain Statistics:
Mean Distance: 2.0196
Covariance Distance: 15.1903
MMD Distance: 4.0787

Analyzing feature space for coral...

coral Domain Statistics:
Mean Distance: 0.0641
Covariance Distance: 0.0401
MMD Distance: 0.0041

Analyzing feature space for adversarial...

adversarial Domain Statistics:
Mean Distance: 1.2607
Covariance Distance: 10.5836
MMD Distance: 1.5894

Analyzing feature space for adabn...

adabn Domain Statistics:
Mean Distance: 1.6048
Covariance Distance: 13.6504
MMD Distance: 2.5753
