In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, random_split
import os
import zipfile
import urllib.request
import time

# Use the CUDA GPU when one is visible to torch, otherwise run on the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(f'Using device: {device}')

# Step 1: Fetch TinyImageNet (a subset of ImageNet) once; when the extracted
# folder already exists, nothing is downloaded.
url = "http://cs231n.stanford.edu/tiny-imagenet-200.zip"
data_dir = './data/tiny-imagenet-200'
_archive_path = './tiny-imagenet-200.zip'

_needs_download = not os.path.exists(data_dir)
if _needs_download:
    print("Downloading TinyImageNet dataset...")
    urllib.request.urlretrieve(url, _archive_path)

    # Unpack under ./data; assumes the archive's top-level folder is
    # tiny-imagenet-200 so that it lands at data_dir — TODO confirm.
    with zipfile.ZipFile(_archive_path, 'r') as archive:
        archive.extractall('./data')
    print("Dataset downloaded and extracted.")

# Step 2: Define data transformations
input_size = 224  # side length (pixels) of the square crop fed to ResNet
batch_size = 16  # kept small to limit memory use

# Per-channel mean / std used by both pipelines' Normalize step.
_norm_mean = [0.485, 0.456, 0.406]
_norm_std = [0.229, 0.224, 0.225]

# Training: random-scale crop plus horizontal flip for augmentation.
_train_pipeline = transforms.Compose([
    transforms.RandomResizedCrop(input_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
])
# Evaluation: deterministic resize followed by a center crop.
_val_pipeline = transforms.Compose([
    transforms.Resize(input_size),
    transforms.CenterCrop(input_size),
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
])

data_transforms = {'train': _train_pipeline, 'val': _val_pipeline}

# Step 3: Load datasets
train_dir = os.path.join(data_dir, 'train')
full_dataset = datasets.ImageFolder(train_dir, transform=data_transforms['train'])
# A second view of the same folder with the deterministic 'val' pipeline.
# Without it the validation split would inherit the random crop/flip of the
# training transform. Both views scan the same directory, so position i refers
# to the same image in each.
full_dataset_eval = datasets.ImageFolder(train_dir, transform=data_transforms['val'])

# Use only the first 80 classes
class_subset = 80
targets = torch.tensor(full_dataset.targets)
# Vectorised mask instead of indexing the tensor one element at a time in Python.
indices = torch.nonzero(targets < class_subset).flatten().tolist()
subset_dataset = torch.utils.data.Subset(full_dataset, indices)

# Split the subset into training and validation (80% train, 20% val)
train_size = int(0.8 * len(subset_dataset))
val_size = len(subset_dataset) - train_size
train_dataset, _val_split = random_split(subset_dataset, [train_size, val_size])
# Map the validation positions (indices into subset_dataset) back to
# full-dataset indices and serve them from the eval-transform view.
val_dataset = torch.utils.data.Subset(
    full_dataset_eval, [indices[i] for i in _val_split.indices]
)

# Data loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

# Step 4: Build a randomly initialised ResNet-50 and swap its classifier head
# for one that emits `class_subset` logits.
# NOTE(review): newer torchvision deprecates `pretrained=` in favour of `weights=`.
model = models.resnet50(pretrained=False)
num_features = model.fc.in_features  # input width of the original head
model.fc = nn.Linear(in_features=num_features, out_features=class_subset)
model.to(device)  # nn.Module.to moves parameters in place and returns the same module

# Step 5: Module-level state shared by the RPQ lookup and the training loop.
mcache = dict()  # RPQ signature tuple -> cached output tensor
mcache_limit = 5_000  # new entries are only added while len(mcache) is below this
cached_weights = dict()  # Conv2d child name -> weight tensor reused by the WS (Weight-Stationary) path
total_computations = 0  # running computation counter, incremented during RPQ and training

def rpq_function(input_tensor, random_matrix):
    """Compute RPQ signatures for a batch and look each one up in ``mcache``.

    Every sample is flattened, multiplied by ``random_matrix`` and reduced to
    a tuple of 0/1 ints (1 where the projection is positive). That tuple is
    the cache key.

    Args:
        input_tensor: batch tensor; dim 0 indexes samples and all remaining
            dims are flattened. Must be viewable (contiguous).
        random_matrix: 2-D projection matrix whose row count equals the
            flattened per-sample size.

    Returns:
        ``(outputs, signatures)``: ``outputs[i]`` is the cached value for
        sample ``i`` or ``None`` on a miss; ``signatures[i]`` is its key.

    Side effect:
        Adds ``batch * features * bits`` (the matmul's multiply-accumulate
        count) to the global ``total_computations``.
    """
    global total_computations

    flat = input_tensor.view(input_tensor.size(0), -1)
    scores = flat @ random_matrix
    total_computations += flat.numel() * random_matrix.size(1)

    # One host transfer for the whole batch, then one hashable key per row.
    bit_rows = (scores > 0).int().tolist()
    signatures = [tuple(row) for row in bit_rows]

    return [mcache.get(sig) for sig in signatures], signatures

# Step 6: Training function with RPQ and WS Dataflow
def _forward_with_ws_cache(model, x):
    """Apply ``model``'s direct children to ``x`` in order and return the result.

    Direct ``nn.Conv2d`` children run through ``nn.functional.conv2d`` with the
    weight tensor held in the global ``cached_weights``, keyed by child name
    and stored on first use. The activation is flattened to
    ``(batch, features)`` right before the child named ``"fc"``. Every other
    child is called as-is.

    Assumes the model's own forward is exactly "children in order, flatten
    before fc" (true for torchvision ResNet; verify for other models).

    Multiply-accumulate counts for the direct Conv2d children and ``fc`` are
    added to the global ``total_computations``.
    """
    global total_computations
    for name, layer in model.named_children():
        if isinstance(layer, nn.Conv2d):
            # Store the live parameter, not a detached clone. A clone would
            # receive no gradient and would go stale after optimizer.step().
            if name not in cached_weights:
                cached_weights[name] = layer.weight
            weight = cached_weights[name]
            x = nn.functional.conv2d(
                x, weight, bias=layer.bias, stride=layer.stride,
                padding=layer.padding, dilation=layer.dilation, groups=layer.groups,
            )
            # Each output element costs (in_channels / groups) * kh * kw MACs.
            total_computations += x.numel() * weight[0].numel()
        elif name == "fc":
            x = layer(torch.flatten(x, 1))
            # Each output logit costs in_features MACs.
            total_computations += x.numel() * layer.weight.size(1)
        else:
            x = layer(x)
    return x


def _evaluate(model, loader, criterion):
    """Return ``(mean loss, accuracy)`` of ``model`` over ``loader`` in eval mode without gradients."""
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss_sum += criterion(outputs, labels).item() * inputs.size(0)
            n_correct += (outputs.argmax(1) == labels).sum().item()
            n_seen += labels.size(0)
    n_seen = max(n_seen, 1)  # avoid ZeroDivisionError on an empty loader
    return loss_sum / n_seen, n_correct / n_seen


def train_model_with_rpq_and_ws(model, train_loader, val_loader, criterion, optimizer,
                                scheduler, num_epochs=10, projection_dim=512):
    """Train ``model`` with RPQ output caching and a WS conv path; validate every epoch.

    For each training batch, ``rpq_function`` maps every sample to a signature.
    Samples whose signature is already in ``mcache`` reuse the stored
    (detached) logits and contribute no gradient. The remaining samples are
    run together through the model's children. Their logits are added to
    ``mcache`` while it holds fewer than ``mcache_limit`` entries.

    Args:
        model: network whose direct children, applied in order with a flatten
            before the child named ``fc``, produce the logits.
        train_loader: loader yielding ``(inputs, labels)`` batches for training.
        val_loader: loader yielding ``(inputs, labels)`` batches for validation.
        criterion: loss taking ``(logits, labels)``.
        optimizer: stepped after every batch that produced a gradient.
        scheduler: stepped once per epoch.
        num_epochs: number of epochs to run.
        projection_dim: number of random-projection bits per RPQ signature.

    Returns:
        Total wall-clock training time in seconds.
    """
    # Built on the first batch so its row count matches the flattened input
    # size instead of assuming 3x224x224 images.
    random_matrix = None

    total_start_time = time.time()
    for epoch in range(num_epochs):
        print(f'Epoch {epoch + 1}/{num_epochs}')
        print('-' * 20)

        # Training phase
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        epoch_start_time = time.time()

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            if random_matrix is None:
                random_matrix = torch.randn(inputs[0].numel(), projection_dim, device=device)

            # Apply RPQ and check cache
            cache_outputs, signatures = rpq_function(inputs, random_matrix)
            optimizer.zero_grad()

            # Run every cache miss through the model as one batch. Per-sample
            # passes would give BatchNorm single-sample statistics in train mode.
            outputs = list(cache_outputs)
            miss_idx = [i for i, cached in enumerate(cache_outputs) if cached is None]
            if miss_idx:
                fresh = _forward_with_ws_cache(model, inputs[miss_idx])
                for row, i in enumerate(miss_idx):
                    outputs[i] = fresh[row:row + 1]
                    if len(mcache) < mcache_limit:
                        mcache[signatures[i]] = outputs[i].detach()

            # Concatenate outputs to form a batch (cached rows carry no gradient)
            outputs = torch.cat(outputs, dim=0)
            loss = criterion(outputs, labels)
            # When every sample was a cache hit, nothing requires grad and
            # backward() would raise, so skip the update for this batch.
            if loss.requires_grad:
                loss.backward()
                optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            preds = outputs.argmax(1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

        epoch_duration = time.time() - epoch_start_time
        seen = max(total, 1)  # avoid ZeroDivisionError on an empty loader
        epoch_loss = running_loss / seen
        epoch_acc = correct / seen
        print(f'Train Loss: {epoch_loss:.4f}, Train Accuracy: {epoch_acc:.4f}')
        print(f'Epoch {epoch + 1} duration: {epoch_duration:.4f} seconds')

        # Validation phase
        val_epoch_loss, val_epoch_acc = _evaluate(model, val_loader, criterion)
        print(f'Val Loss: {val_epoch_loss:.4f}, Val Accuracy: {val_epoch_acc:.4f}')

        scheduler.step()

    total_duration = time.time() - total_start_time
    print(f'Total training time with RPQ and WS: {total_duration:.4f} seconds')
    print(f"Total Computations during Training: {total_computations}")
    return total_duration


Using device: cuda
