# Test 04: Torchvision Downloads

**Purpose:** Download torchvision pretrained models and verify they go to the correct location.

**What we'll download:**
1. A small pretrained model (ResNet18) via the `torchvision.models` API
2. A second small pretrained model (MobileNetV2) via the `torch.hub` API

**Expected Locations:**
- **CoCalc Home:** `~/home_workspace/downloads/` (TORCH_HOME)
- **Compute Server:** `~/cs_workspace/downloads/` (TORCH_HOME)
- **NOT:** `~/.cache/torch/`

**Run this on:** Both CoCalc base and Compute Server

---

## IMPORTANT: Restart kernel before running this test!

In [None]:
# DS776 Environment Setup & Package Update
# Configures storage paths for proper cleanup/sync, then updates introdl if needed
# If this cell fails, see Lessons/Course_Tools/AUTO_UPDATE_SYSTEM.md for help
# NOTE(review): the script path is relative, so it presumably resolves against
# the notebook's working directory - confirm the notebook is run from its own folder.
%run ../../Lessons/Course_Tools/auto_update_introdl.py

In [None]:
# Pre-download check: report anything already present under ~/.cache/torch
# so it is not mistaken for output produced by this test.
from pathlib import Path
import os

home = Path.home()
bad_cache = home.joinpath('.cache', 'torch')
banner = "=" * 60

print(banner)
print("PRE-DOWNLOAD CHECK: ~/.cache/torch")
print(banner)

if not bad_cache.exists():
    print("Good: ~/.cache/torch does not exist")
else:
    # Total bytes of every regular file anywhere below the directory.
    total_size = 0
    for entry in bad_cache.rglob('*'):
        if entry.is_file():
            total_size += entry.stat().st_size
    print(f"WARNING: ~/.cache/torch exists ({total_size / 1024 / 1024:.1f} MB)")

    # Per-subdirectory breakdown (only immediate subdirectories are listed).
    print("Contents:")
    for item in bad_cache.iterdir():
        if not item.is_dir():
            continue
        size = 0
        for entry in item.rglob('*'):
            if entry.is_file():
                size += entry.stat().st_size
        print(f"  {item.name}/: {size / 1024 / 1024:.1f} MB")
    print("\nNote: This is pre-existing content, not from this test.")

In [None]:
# Check expected cache location
import torch
import os
from pathlib import Path


def hub_dir_under_torch_home(torch_home_value, hub_dir_value):
    """Return True when the hub directory is TORCH_HOME itself or lies beneath it.

    Args:
        torch_home_value: Raw TORCH_HOME value, or None / '' when it is unset.
        hub_dir_value: Directory to test, e.g. the result of torch.hub.get_dir().

    Returns:
        bool: False when TORCH_HOME is unset or empty; otherwise whether the
        hub path equals TORCH_HOME or has it as an ancestor.

    The comparison is done on path components rather than on raw substrings,
    so '/a/b' does not match '/x/a/b/hub' or '/a/bc/hub'. A leading '~' is
    expanded on both sides so an unexpanded TORCH_HOME still matches an
    expanded hub path.
    """
    if not torch_home_value:
        return False
    base = Path(os.path.expanduser(str(torch_home_value)))
    hub = Path(os.path.expanduser(str(hub_dir_value)))
    return base == hub or base in hub.parents


print("\n" + "=" * 60)
print("EXPECTED CACHE LOCATION")
print("=" * 60)

# 'NOT SET' is only a display placeholder; the check below reads the raw value.
torch_home = os.environ.get('TORCH_HOME', 'NOT SET')
reported_hub_dir = torch.hub.get_dir()
print(f"TORCH_HOME: {torch_home}")
print(f"torch.hub.get_dir(): {reported_hub_dir}")

# These should match
if hub_dir_under_torch_home(os.environ.get('TORCH_HOME'), reported_hub_dir):
    print("\nGood: torch.hub is using TORCH_HOME")
else:
    print("\nWARNING: torch.hub may not be using TORCH_HOME")

## Test A: Download via torchvision.models API

In [None]:
# Fetch ResNet18 with pretrained weights through the torchvision.models API
import torchvision.models as models

print("\nDownloading ResNet18 via torchvision.models...")
print("(This may take a minute on first run...)\n")

# Asking for named weights is what pulls the pretrained checkpoint
model = models.resnet18(weights='IMAGENET1K_V1')

# Tally the element count of every parameter tensor
param_count = 0
for tensor in model.parameters():
    param_count += tensor.numel()
model_type_name = type(model).__name__

print("Model downloaded successfully")
print(f"Model type: {model_type_name}")
print(f"Number of parameters: {param_count:,}")

In [None]:
# Verify download location
from pathlib import Path
import os
import torch


def find_checkpoints(directory, keyword):
    """Return entries of *directory* whose name contains *keyword*, ignoring case.

    Args:
        directory: Directory to scan (str or Path); only its immediate
            entries are inspected.
        keyword: Substring to look for in each entry name.

    Returns:
        list[Path]: Matching entries sorted by path; an empty list when the
        directory does not exist or nothing matches.
    """
    directory = Path(directory)
    if not directory.is_dir():
        return []
    needle = keyword.lower()
    return sorted(p for p in directory.iterdir() if needle in p.name.lower())


print("\n" + "=" * 60)
print("TORCHVISION DOWNLOAD VERIFICATION")
print("=" * 60)

home = Path.home()
hub_dir = Path(torch.hub.get_dir())

# Check expected location (hub/checkpoints)
expected_checkpoints = hub_dir / 'checkpoints'
if expected_checkpoints.exists():
    print(f"\nCorrect location ({expected_checkpoints}):")
    resnet_in_expected = find_checkpoints(expected_checkpoints, 'resnet18')
    for item in resnet_in_expected:
        size = item.stat().st_size / 1024 / 1024
        print(f"  FOUND: {item.name} ({size:.1f} MB)")
    if not resnet_in_expected:
        # Without this, an empty listing is indistinguishable from a pass.
        print("  WARNING: no ResNet18 checkpoint found here")
else:
    print(f"\nExpected location not found: {expected_checkpoints}")
    # Check hub_dir directly
    print(f"Checking {hub_dir}:")
    if hub_dir.exists():
        for item in hub_dir.iterdir():
            print(f"  {item.name}")

# Check bad location
bad_checkpoints = home / '.cache' / 'torch' / 'hub' / 'checkpoints'
if bad_checkpoints.exists():
    resnet_in_bad = find_checkpoints(bad_checkpoints, 'resnet18')
    if resnet_in_bad:
        print("\nWARNING: ResNet18 found in ~/.cache/torch/!")
        for f in resnet_in_bad:
            print(f"  {f.name}")
    else:
        print("\nGood: No new ResNet18 in ~/.cache/torch/")
else:
    print("\nGood: ~/.cache/torch/hub/checkpoints does not exist")

## Test B: Download via torch.hub API

In [None]:
# Fetch MobileNetV2 through the torch.hub API (different API, same cache)
import torch

print("\nDownloading MobileNetV2 via torch.hub...")
print("(This may take a minute on first run...)\n")

# Load the model entry point from the pinned pytorch/vision repository tag
hub_repo = 'pytorch/vision:v0.10.0'
model = torch.hub.load(hub_repo, 'mobilenet_v2', pretrained=True)

# Tally the element count of every parameter tensor
param_count = 0
for tensor in model.parameters():
    param_count += tensor.numel()
model_type_name = type(model).__name__

print("Model downloaded successfully")
print(f"Model type: {model_type_name}")
print(f"Number of parameters: {param_count:,}")

In [None]:
# Verify torch.hub download location
from pathlib import Path
import os
import torch


def size_mb(path):
    """Return the size of *path* in megabytes (bytes / 1024 / 1024).

    Args:
        path: File or directory (str or Path).

    Returns:
        float: For a directory, the combined size of every regular file
        beneath it; otherwise the size of the file itself.
    """
    path = Path(path)
    if path.is_dir():
        total = sum(f.stat().st_size for f in path.rglob('*') if f.is_file())
    else:
        total = path.stat().st_size
    return total / 1024 / 1024


print("\n" + "=" * 60)
print("TORCH.HUB DOWNLOAD VERIFICATION")
print("=" * 60)

hub_dir = Path(torch.hub.get_dir())

print(f"\nHub directory: {hub_dir}")
if hub_dir.exists():
    print("Contents:")
    for item in sorted(hub_dir.iterdir()):
        # Directories keep a trailing slash so they stand out from files.
        suffix = '/' if item.is_dir() else ''
        print(f"  {item.name}{suffix}: {size_mb(item):.1f} MB")
else:
    print("  WARNING: hub directory does not exist")

# Check for mobilenet in checkpoints
checkpoints = hub_dir / 'checkpoints'
if checkpoints.is_dir():
    print("\nCheckpoints directory:")
    mobilenet_files = sorted(
        f for f in checkpoints.iterdir() if 'mobilenet' in f.name.lower()
    )
    for f in mobilenet_files:
        print(f"  FOUND: {f.name}")
    if not mobilenet_files:
        # Without this, an empty listing is indistinguishable from a pass.
        print("  WARNING: no MobileNet checkpoint found here")
else:
    print(f"\nWARNING: checkpoints directory not found: {checkpoints}")

## Final Summary

In [None]:
# Final summary: size of the torch.hub directory versus the unwanted ~/.cache/torch
from pathlib import Path
import os
import torch


def _bytes_under(root):
    """Add up the sizes of all regular files found anywhere below *root*."""
    total = 0
    for entry in root.rglob('*'):
        if entry.is_file():
            total += entry.stat().st_size
    return total


rule = "=" * 60
print("\n" + rule)
print("FINAL SUMMARY: Torchvision/PyTorch Downloads")
print(rule)

home = Path.home()
hub_dir = Path(torch.hub.get_dir())

print("\nCorrect Location:")
if hub_dir.exists():
    size = _bytes_under(hub_dir)
    print(f"  torch.hub: {hub_dir} ({size / 1024 / 1024:.1f} MB)")

# Check bad location
bad_torch = home.joinpath('.cache', 'torch')
print("\nBad Location (~/.cache/torch):")
if not bad_torch.exists():
    print("  Does not exist - PERFECT!")
else:
    size = _bytes_under(bad_torch)
    print(f"  EXISTS with {size / 1024 / 1024:.1f} MB")
    print("  (May be pre-existing content, check timestamps)")

print("\n" + rule)
print("If downloads went to correct locations, the fix is working!")
print(rule)

## Next Steps

- **Test_05:** Full verification of all cache locations and cleanup test