In [None]:
!pip install torch
!pip install torchvision
!pip install Pillow
!pip install psutil

In [None]:
import subprocess
from typing import Callable, Optional

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from psutil import virtual_memory
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torchvision import datasets


#@title Show Runtime Information
# Query nvidia-smi through subprocess instead of the IPython `!` magic, so this
# cell is valid plain Python. This also detects a missing nvidia-smi binary:
# the shell's "command not found" message does not contain 'failed', so a
# text-only check would print that error as if it were GPU information.
try:
    _smi = subprocess.run(
        ['nvidia-smi'], capture_output=True, text=True, check=False
    )
except OSError:
    # Binary missing (FileNotFoundError) or not executable (PermissionError).
    gpu_info = ''
    gpu_connected = False
else:
    gpu_info = (_smi.stdout + _smi.stderr).strip()
    # Keep the original "'failed' in output" heuristic, and also treat a
    # non-zero exit status as "no usable GPU".
    gpu_connected = _smi.returncode == 0 and 'failed' not in gpu_info

if gpu_connected:
    print(gpu_info)
else:
    print('Not connected to a GPU')

# NOTE: psutil's `total` is the machine's total physical memory, not the
# amount currently free; the message wording is kept as-is.
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

if ram_gb < 20:
    print('Not using a high-RAM runtime')
else:
    print('You are using a high-RAM runtime!')


def _collate_skip_none(batch):
    """Collate ``(image, target)`` samples, dropping those whose image is ``None``.

    Returns ``None`` when every sample in ``batch`` was dropped, so the caller
    can skip the batch instead of collating an empty list.
    """
    kept = [sample for sample in batch if sample is not None and sample[0] is not None]
    if not kept:
        return None
    return default_collate(kept)


class FilteredDataLoader(DataLoader):
    """``DataLoader`` that drops samples whose image is ``None``.

    Filtering must happen *before* collation. ``DataLoader`` collates each
    batch before ``__iter__`` sees it, and ``default_collate`` raises on
    ``None``. Filtering an already-collated tensor can therefore never remove
    anything.

    Unless the caller passes its own (non-``None``) ``collate_fn``, this loader
    uses ``_collate_skip_none``. It then skips batches in which every sample
    was dropped.

    Each yielded item is an ``(inputs, targets)`` tuple.

    Note: ``len(loader)`` still counts batches that end up being skipped.
    """

    # Position of ``collate_fn`` among DataLoader.__init__'s positional
    # parameters: dataset, batch_size, shuffle, sampler, batch_sampler,
    # num_workers, collate_fn.
    _COLLATE_FN_POSITION = 6

    def __init__(self, *args, **kwargs):
        collate_given_positionally = len(args) > self._COLLATE_FN_POSITION
        if not collate_given_positionally and kwargs.get("collate_fn") is None:
            kwargs["collate_fn"] = _collate_skip_none
        super().__init__(*args, **kwargs)

    def __iter__(self):
        for batch in super().__iter__():
            # _collate_skip_none returns None when the whole batch was dropped.
            if batch is None:
                continue
            inputs, targets = batch
            yield inputs, targets


class CustomImageFolder(datasets.ImageFolder):
    """``datasets.ImageFolder`` that skips image files Pillow cannot read.

    Every file found under ``root`` is checked with ``PIL.Image.verify()``
    while the folder is indexed. Files that fail are reported and left out of
    ``self.samples``.

    Because ``is_valid_file`` is supplied, torchvision's extension filter is
    not used. Non-image files are therefore opened and rejected here too.

    ``verify()`` does not fully decode pixel data, so a file can still fail to
    load in ``__getitem__``. In that case the sample's image is returned as
    ``None`` and the transform is skipped. Callers must drop such samples
    before batching.

    Args:
        root: Dataset root, laid out as ``root/<class_name>/<image file>``.
        transform: Optional callable applied to each loaded RGB image.
        target_transform: Optional callable applied to each class index.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ):
        def is_valid_file(path: str) -> bool:
            try:
                with open(path, "rb") as f:
                    Image.open(f).verify()
                return True
            # Pillow reports unreadable/unidentified files with OSError, but
            # some corrupt chunks (e.g. bad PNG checksums) raise SyntaxError.
            except (OSError, SyntaxError):
                print(f"Warning: Could not read image {path}, skipping.")
                return False

        super().__init__(
            root,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file,
        )

    def _load_image(self, path: str) -> Optional[Image.Image]:
        """Return the image at ``path`` converted to RGB, or ``None`` if it cannot be decoded."""
        try:
            # The context manager releases the file handle. convert() loads the
            # pixels into a new image, so the result stays valid after closing.
            with Image.open(path) as img:
                return img.convert("RGB")
        except (OSError, SyntaxError):
            print(f"Warning: Could not read image {path}, skipping.")
            return None

    def _filter_corrupted_images(self, root: str):
        """Return ``(samples, targets)`` for the images under ``root`` that decode.

        Not called by any code visible in this file; kept for compatibility.
        """
        samples = []
        targets = []
        for path, target in datasets.ImageFolder(root).samples:
            if self._load_image(path) is not None:
                samples.append((path, target))
                targets.append(target)
        return samples, targets

    def __getitem__(self, index: int):
        """Return ``(image, target)`` for ``index``; ``image`` is ``None`` if it could not be read."""
        path, target = self.samples[index]
        img = self._load_image(path)
        # The transform expects an image, so only apply it when loading succeeded.
        if img is not None and self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        return len(self.samples)


# Load and preprocess the dataset
print('Loading and preprocessing the dataset...')
model_path = '/your/path/datasets/gpt-4-cnn-starter.pth'
image_path = '/your/path/image.png'

# Deterministic preprocessing shared by both splits:
# - Resize(224) scales the shorter edge to 224 px (torchvision semantics).
# - CenterCrop takes the central 224x224 patch.
# - ToTensor converts to a float tensor.
# - Normalize standardizes each of the 3 channels with fixed mean/std values.
transform = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Build the training split first, then the test split.
train_data, test_data = (
    CustomImageFolder(root=split_root, transform=transform)
    for split_root in ('/your/path/train', '/your/path/test')
)

# Shuffle only the training split; keep the test order fixed.
train_loader, test_loader = (
    FilteredDataLoader(split, batch_size=32, shuffle=is_train)
    for split, is_train in ((train_data, True), (test_data, False))
)


# Define the CNN architecture
print('Defining the CNN architecture...')
class SimpleCNN(nn.Module):
    """Small two-stage CNN followed by a two-layer fully connected classifier.

    ``features`` consists of two [3x3 conv (padding 1) -> ReLU -> 2x2 max-pool]
    stages, taking 3 input channels to 128. The padded convolutions keep the
    spatial size and each pool halves it (rounding down). A square
    ``input_size`` x ``input_size`` image therefore yields a 128 x s x s
    feature map, with ``s = input_size // 4``.

    ``classifier`` is Linear(128*s*s -> 4096) -> ReLU -> Dropout ->
    Linear(4096 -> num_classes), producing one logit per class.

    With the default ``input_size=224`` (s = 56), the first Linear layer alone
    holds 128*56*56*4096 ~= 1.64e9 weights (about 6.6 GB in float32).

    Args:
        num_classes: Number of output logits.
        input_size: Side length in pixels of the square inputs the classifier
            is sized for. Defaults to 224, giving a classifier input of
            128 * 56 * 56 features.

    Raises:
        ValueError: If ``input_size`` is smaller than 4, which would leave an
            empty feature map.
    """

    def __init__(self, num_classes, input_size=224):
        super().__init__()
        if input_size < 4:
            raise ValueError(f"input_size must be >= 4, got {input_size}")
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Two floor-halving pools: floor(floor(n / 2) / 2) == n // 4.
        feature_side = input_size // 4
        self.classifier = nn.Sequential(
            nn.Linear(128 * feature_side * feature_side, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        """Map a (batch, 3, input_size, input_size) tensor to (batch, num_classes) logits."""
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


# One output logit per class folder discovered under the training root.
num_classes = len(train_data.classes)
model = SimpleCNN(num_classes)

# Set up loss function and optimizer
print('Setting up loss function and optimizer...')
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=model.parameters(), lr=1e-3, momentum=0.9)

# Train the model
print('Training the model...')
num_epochs = 10
for epoch in range(num_epochs):
    # Put layers such as Dropout into training mode for this epoch.
    model.train()
    for batch_idx, batch in enumerate(train_loader):
        inputs, targets = batch

        # One SGD step: clear stale gradients, forward, backprop, update.
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Report progress only on batches 0, 10, 20, ...
        if batch_idx % 10 != 0:
            continue
        print(f'Epoch {epoch+1}/{num_epochs} Batch {batch_idx}/{len(train_loader)} Loss: {loss.item():.4f}')

# Evaluate the model
print('Evaluating the model...')
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for i, (inputs, targets) in enumerate(test_loader):
        print(f'Evaluating image batch {i+1}/{len(test_loader)}')
        # Score the whole batch in a single forward pass.
        outputs = model(inputs)
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

# An empty test set would otherwise make correct / total raise ZeroDivisionError.
if total:
    accuracy = correct / total
    print(f'Test accuracy: {accuracy:.2%}')
else:
    accuracy = float('nan')
    print('Test accuracy: undefined (no test images were evaluated)')

# Persist only the learned parameters (the state dict), not the module object.
torch.save(model.state_dict(), model_path)

# Rebuild the architecture from scratch, then restore the saved parameters into it.
model = SimpleCNN(num_classes)
saved_state = torch.load(model_path)
model.load_state_dict(saved_state)

# Print the model's architecture
print(model)

# Print each parameter's name and shape (not the weight values themselves).
for name, param in model.named_parameters():
    print(f"{name}: {param.size()}")

# Invert ImageFolder's {label: index} mapping so predicted indices can be decoded.
class_to_idx = train_data.class_to_idx
idx_to_class = dict(zip(class_to_idx.values(), class_to_idx.keys()))

# Print the mapping
print(idx_to_class)

# Example: look up the label for class index 0
class_index = 0
class_label = idx_to_class[class_index]
print(f"Class index {class_index} corresponds to label '{class_label}'")


# Load an image and preprocess it
print('Loading and preprocessing an image...')
# The context manager closes the file once the RGB copy exists; convert()
# returns a fully loaded new image, so `image` stays usable afterwards.
with Image.open(image_path) as raw_image:
    image = raw_image.convert('RGB')

# Same preprocessing as the dataset transform defined earlier in this file.
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

preprocessed_image = transform(image).unsqueeze(0)  # Add batch dimension

# Run the model on the preprocessed image
print('Running the model on the image...')
model.eval()
with torch.no_grad():
    outputs = model(preprocessed_image)
    # Index of the highest logit along the class dimension.
    _, predicted_class_index = outputs.max(1)

predicted_label = idx_to_class[predicted_class_index.item()]
print(f"The predicted label for the image is '{predicted_label}'.")