In [None]:
# Setup and imports
from pathlib import Path
import os
import warnings
warnings.filterwarnings('ignore', category=UserWarning, module='PIL.Image')

# Root folder of the image dataset: a "datasets" directory under the current working directory.
DATASET_DIR = Path.cwd() / "datasets"


In [None]:
# WandB initialization
def init_wandb(config=None, entity="prinzz-personal", project="gadgets-predictor"):
    """Start a Weights & Biases run for this notebook, or return None if that fails.

    ``wandb`` is imported lazily inside the ``try`` so that a missing package or a
    failed ``wandb.init`` (e.g. no network / no login) only disables remote logging
    instead of stopping the notebook.

    Args:
        config: Optional dict of overrides merged on top of the default run config
            defined below (keys in ``config`` win).
        entity: WandB entity (user or team) that owns the run.
        project: WandB project name.

    Returns:
        Whatever ``wandb.init`` returns, or ``None`` if importing wandb or starting
        the run raised any exception.
    """
    run_config = {
        # NOTE(review): the training cell calls fine_tune() without base_lr, so this
        # value is not the LR training actually uses — keep in sync or pass it through.
        "learning_rate": 0.02,
        "architecture": "ResNet18",
        "epochs": 3,
        "batch_size": 64,
        "image_size": 224,
    }
    if config:
        run_config.update(config)

    try:
        import wandb
        return wandb.init(entity=entity, project=project, config=run_config)
    except Exception as e:
        # Deliberately best-effort: callers treat None as "log locally only".
        print(f"WandB initialization failed: {e}")
        return None


In [None]:
# Data loading and preprocessing
from fastai.vision.all import ImageDataLoaders, Resize, Transform
from PIL import Image

def convert_to_rgb(img):
    """Return ``img`` as a 3-channel RGB image.

    RGB images are returned unchanged (same object). Palette ("P") images that
    carry transparency are expanded through RGBA first, as Pillow recommends,
    before the alpha channel is dropped. Every other mode (RGBA, LA, L, CMYK, ...)
    is converted straight to RGB; any alpha channel is discarded, not composited.

    Args:
        img: A PIL image (anything exposing ``mode``, ``info`` and ``convert``).

    Returns:
        An image whose ``mode`` is ``"RGB"``.
    """
    if img.mode == "RGB":
        return img
    if img.mode == "P" and "transparency" in img.info:
        # Previously this RGBA result was returned as-is, yielding 4-channel images.
        img = img.convert("RGBA")
    return img.convert("RGB")

class ConvertToRGB(Transform):
    """fastai item transform that normalises images to RGB via ``convert_to_rgb``.

    ``encodes`` is annotated with ``Image.Image`` so fastai's type dispatch applies
    it only to PIL images (fastai's ``PILImage`` subclasses ``Image.Image``); other
    elements of the (image, label) tuple, such as the category label, pass through
    untouched instead of being probed for a ``.mode`` attribute.
    """
    def encodes(self, img: Image.Image):
        return convert_to_rgb(img)

# Create data loaders.
# With valid_pct set, from_folder holds out a random 20% of the images for
# validation and trains on the rest. (`train_pct` is not a from_folder parameter,
# so it is not passed.) A fixed seed makes the split identical on every
# Restart & Run All, so metrics are comparable between runs.
SPLIT_SEED = 42
dls = ImageDataLoaders.from_folder(
    DATASET_DIR,
    valid_pct=0.2,
    seed=SPLIT_SEED,
    item_tfms=[ConvertToRGB(), Resize(224)],
)

print(f"Classes: {dls.vocab}")
print(f"Training samples: {len(dls.train_ds)}")
print(f"Validation samples: {len(dls.valid_ds)}")


In [None]:
# Model training with WandB logging
from fastai.vision.all import vision_learner, resnet18, error_rate, accuracy
from fastai.callback.tracker import TrackerCallback

# wandb is optional: init_wandb() already falls back to local-only training when it
# cannot be imported, so an unconditional import here would defeat that fallback.
# Later code only touches `wandb` when a run exists, which implies the import worked.
try:
    import wandb
except ImportError:
    wandb = None

class WandbMetricsCallback(TrackerCallback):
    """Custom callback that logs each finished epoch's metrics to a WandB run.

    Args:
        wandb_run: Run object returned by ``init_wandb()``. Logging is skipped when
            it is falsy or has no ``log`` method.
        monitor: Metric name forwarded to ``TrackerCallback``.
    """
    def __init__(self, wandb_run, monitor='valid_loss'):
        super().__init__(monitor=monitor)
        self.wandb_run = wandb_run

    def after_epoch(self):
        super().after_epoch()
        if not (self.wandb_run and hasattr(self.wandb_run, 'log')):
            return
        try:
            last_values = self.learn.recorder.values[-1]
            # Recorder rows are ordered: train_loss, valid_loss, then the learner's metrics.
            metric_names = ['train_loss', 'valid_loss'] + [m.name for m in self.learn.metrics]

            metrics = {'epoch': self.epoch}
            # zip pairs names with values and stops at the shorter list.
            for name, value in zip(metric_names, last_values):
                metrics[name] = float(value)

            # Learning rate of the first parameter group (with discriminative LRs,
            # other groups may use different values).
            if hasattr(self.learn, 'opt') and self.learn.opt:
                metrics['learning_rate'] = self.learn.opt.hypers[0]['lr']

            self.wandb_run.log(metrics)
            print(f"✓ Epoch {self.epoch} metrics logged to WandB")
        except Exception as e:
            # Logging is best-effort: never interrupt training because of WandB.
            print(f"✗ Failed to log epoch {self.epoch}: {e}")

# Start a WandB run (init_wandb returns None when it cannot) and build a
# ResNet18 learner from pretrained weights, tracking accuracy and error rate.
wandb_run = init_wandb()
learn = vision_learner(
    dls,
    resnet18,
    metrics=[accuracy, error_rate],
    pretrained=True,
)

# Attach the WandB logging callback only when a run exists; otherwise train plainly.
callbacks = []
if wandb_run:
    callbacks.append(WandbMetricsCallback(wandb_run))
learn.fine_tune(3, cbs=callbacks)


In [None]:
# Model evaluation and results
learn.show_results(figsize=(12, 8), max_n=6)  # grid of up to 6 validation predictions


In [None]:
# Save model and upload to WandB
import time
model_name = "gadget_classifier_model"
model_path = Path.cwd() / "models" / f"{model_name}.pkl"
model_path.parent.mkdir(parents=True, exist_ok=True)

try:
    # Export the trained model. export() has written the file by the time it
    # returns, so the file can be checked immediately (no sleep/"sync" needed).
    print("Saving model...")
    learn.export(model_path)

    if not model_path.exists():
        raise FileNotFoundError(f"Model file not found at {model_path}")
    file_size = model_path.stat().st_size / (1024 * 1024)
    print(f"✓ Model saved as {model_name}.pkl ({file_size:.2f} MB)")

    if wandb_run:
        # Local import: only needed (and only guaranteed importable) when a run exists.
        import wandb

        # Recorder rows are [train_loss, valid_loss, *metrics]. Look metrics up by
        # name: positional index 1 is valid_loss, not accuracy.
        final_metric_names = ['train_loss', 'valid_loss'] + [m.name for m in learn.metrics]
        final_metrics = {
            name: float(value)
            for name, value in zip(final_metric_names, learn.recorder.values[-1])
        }

        # Create and upload model artifact
        print("Creating WandB artifact...")
        model_artifact = wandb.Artifact(
            name=model_name,
            type="model",
            description="FastAI ResNet18 model for gadget classification",
            metadata={
                "architecture": "ResNet18",
                "classes": list(dls.vocab),
                "accuracy": final_metrics["accuracy"],
                "error_rate": final_metrics["error_rate"],
            },
        )
        model_artifact.add_file(str(model_path))

        print("Uploading model to WandB...")
        wandb_run.log_artifact(model_artifact)

        print("✓ Model uploaded to WandB as artifact")
        print(f"✓ Artifact name: {model_name}")
        print(f"✓ Classes: {dls.vocab}")
        print(f"✓ Final accuracy: {final_metrics['accuracy']:.4f}")

        wandb_run.finish()
        print("✓ WandB run completed successfully")
    else:
        print("⚠️ No WandB run - model saved locally only")

except Exception as e:
    print(f"✗ Error during model save/upload: {e}")
    import traceback
    traceback.print_exc()
    # Close the run even on failure so it does not stay open in WandB.
    if wandb_run:
        wandb_run.finish()
