In [5]:
# ✅ Import required libraries
!pip install torch torchvision torchaudio
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import random

# ✅ Check device (CPU or GPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Running on {device}")

# ✅ Data Preparation (Example: MVTec AD dataset, or dummy MNIST for quick test)
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

# Load dataset (Replace MNIST with MVTec dataset if available)
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

# ✅ Feature Extractor Model (ResNet18)
model = torchvision.models.resnet18(pretrained=True)
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)  # for grayscale images
model.fc = nn.Linear(512, 2)  # Normal / Anomaly
model = model.to(device)

# ✅ Loss function and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)

# ✅ Training loop (few epochs)
num_epochs = 2
for epoch in range(num_epochs):
    for images, labels in train_loader:
        images = images.to(device)
        
        # Generate synthetic anomaly labels (e.g., numbers >=5 are anomalies)
        labels = (labels >= 5).long().to(device)
        
        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"✅ Epoch [{epoch+1}/{num_epochs}] finished, Loss: {loss.item():.4f}")

print("✅ Training completed!")

# ✅ Test a sample image
model.eval()
sample_img, label = random.choice(train_dataset)
sample_img = sample_img.unsqueeze(0).to(device)

with torch.no_grad():
    output = model(sample_img)
    predicted = torch.argmax(output, dim=1).item()

# ✅ Visualize the result
plt.imshow(sample_img.squeeze(0).squeeze(0).cpu(), cmap='gray')
plt.title(f"Predicted: {'Anomaly' if predicted else 'Normal'}")
plt.axis('off')
plt.show()

Collecting torch
  Downloading torch-2.7.0-cp312-cp312-win_amd64.whl.metadata (29 kB)
Collecting torchvision
  Downloading torchvision-0.22.0-cp312-cp312-win_amd64.whl.metadata (6.3 kB)
Collecting torchaudio
  Downloading torchaudio-2.7.0-cp312-cp312-win_amd64.whl.metadata (6.7 kB)
Collecting sympy>=1.13.3 (from torch)
  Downloading sympy-1.13.3-py3-none-any.whl.metadata (12 kB)
Downloading torch-2.7.0-cp312-cp312-win_amd64.whl (212.5 MB)
   ---------------------------------------- 0.0/212.5 MB ? eta -:--:--
   ---------------------------------------- 0.3/212.5 MB ? eta -:--:--
   ---------------------------------------- 0.5/212.5 MB 1.3 MB/s eta 0:02:45
   ---------------------------------------- 0.5/212.5 MB 1.3 MB/s eta 0:02:45
   ---------------------------------------- 1.0/212.5 MB 1.3 MB/s eta 0:02:40
   ---------------------------------------- 1.3/212.5 MB 1.3 MB/s eta 0:02:44
   ---------------------------------------- 1.6/212.5 MB 1.2 MB/s eta 0:03:02
   ----------------------

ERROR: Exception:
Traceback (most recent call last):
  File "C:\Users\salma\anaconda3\Lib\site-packages\pip\_vendor\urllib3\response.py", line 438, in _error_catcher
    yield
  File "C:\Users\salma\anaconda3\Lib\site-packages\pip\_vendor\urllib3\response.py", line 561, in read
    data = self._fp_read(amt) if not fp_closed else b""
           ^^^^^^^^^^^^^^^^^^
  File "C:\Users\salma\anaconda3\Lib\site-packages\pip\_vendor\urllib3\response.py", line 527, in _fp_read
    return self._fp.read(amt) if amt is not None else self._fp.read()
           ^^^^^^^^^^^^^^^^^^
  File "C:\Users\salma\anaconda3\Lib\site-packages\pip\_vendor\cachecontrol\filewrapper.py", line 98, in read
    data: bytes = self.__fp.read(amt)
                  ^^^^^^^^^^^^^^^^^^^
  File "C:\Users\salma\anaconda3\Lib\http\client.py", line 479, in read
    s = self.fp.read(amt)
        ^^^^^^^^^^^^^^^^^
  File "C:\Users\salma\anaconda3\Lib\socket.py", line 720, in readinto
    return self._sock.recv_into(b)
           ^

ModuleNotFoundError: No module named 'torch'