## Importing modules

In [1]:
import h5py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision.models import resnet50
import numpy as np

import skimage
from skimage.color import rgb2hed, hed2rgb
import pywt
from PIL import Image
from sklearn.decomposition import PCA
import cv2

In [2]:
class H5Dataset(Dataset):
    def __init__(self, image_file, label_file, transform=None):
        self.transform = transform
        
        # Load data from the H5 file
        with h5py.File(image_file, 'r') as f:
            self.images = f['x'][:]
        with h5py.File(label_file, 'r') as f:
            self.labels = f['y'][:].reshape(-1, 1).astype(np.float32)
        
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        
        
        if self.transform:
            image = self.transform(image)
        
        return image, label

In [3]:
class RGB2HED(torch.nn.Module):
    def __init__(self, mode=None):
        super(RGB2HED, self).__init__()
        self.mode = mode
    def forward(self, img):
        img = img.astype(np.float32) / 255.
        hed_img = rgb2hed(img) * 255.
        hed_img = np.tile(hed_img[:, :, -2:-1], reps=(1,1,3))
        return hed_img
    
        
class WaveletTransform(nn.Module):
    def __init__(self, wavelet='haar', threshold=20):
        super(WaveletTransform, self).__init__()
        self.wavelet = wavelet
        self.threshold = threshold
        
    def forward(self, img):
        grayscale_image = np.dot(img.astype(np.uint8), [0.299, 0.587, 0.114])
        
        # Step 2: Perform 2D wavelet decomposition
        coeffs = pywt.wavedec2(grayscale_image, wavelet=self.wavelet, level=2)
        cA, details = coeffs[0], coeffs[1:]
        
        # Step 3: Apply thresholding to detail coefficients
        def threshold_coeffs(coeffs, threshold):
            return [pywt.threshold(c, threshold, mode='soft') for c in coeffs] 
        
        
        details_thresh = [threshold_coeffs(detail, self.threshold) for detail in details]
        coeffs_thresh = [cA] + details_thresh
        
        # Step 4: Reconstruct the image
        compressed_image = pywt.waverec2(coeffs_thresh, wavelet=self.wavelet)
        compressed_image = np.clip(compressed_image, 0, 255).astype(np.uint8)
        compressed_image = np.tile(np.expand_dims(compressed_image, -1), (1,1,3))
        
        return compressed_image
    

class CLAHE(nn.Module):
    def __init__(self, mode=None):
        super(CLAHE, self).__init__()
        self.mode = mode
    def forward(self, image):
        # Convert to LAB color space
        lab_image = cv2.cvtColor(image.astype(np.uint8), cv2.COLOR_RGB2LAB)
        l_channel, a, b = cv2.split(lab_image)

        # Apply CLAHE to the L channel
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        l_channel = clahe.apply(l_channel)

        # Merge and convert back to RGB
        lab_image = cv2.merge((l_channel, a, b))
        return cv2.cvtColor(lab_image, cv2.COLOR_LAB2RGB)
    

class Macenko(nn.Module):
    def __init__(self):
        super(Macenko, self).__init__()
    
    def forward(self, image):
        """Normalize H&E stained images using Macenko method."""
        # Reshape image to 2D
        h, w, c = image.shape
        image_flat = image.reshape((-1, c))

        # PCA for stain separation
        pca = PCA(n_components=c)
        pca.fit(image_flat)
        stains = pca.components_

        # Normalize to intensity ranges
        norms = np.sqrt(np.sum(stains**2, axis=0))
        normalized_stains = stains / norms
        normalized_image = np.dot(image_flat, normalized_stains.T)
        
        # Scale back and reshape
        return normalized_image.reshape((h, w, c))
    
class Opening(nn.Module):
    def __init__(self):
        super(Opening, self).__init__()
        
    def forward(self, image):
        return skimage.morphology.opening(image)


# transform = transforms.Compose([
#     # Opening(),
#     # CLAHE(),
#     # Macenko(),
#     # WaveletTransform(),
#     transforms.ToPILImage(),
#     # transforms.Resize((96, 96)),
#     transforms.ToTensor(),
#     # transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
# ])

transform = transforms.Compose([transforms.ToPILImage(),
                    transforms.ColorJitter(brightness=.5, saturation=.25,
                                        hue=.1, contrast=.5),
                    transforms.RandomAffine(10, (0.05, 0.05), fill=255),
                    transforms.RandomHorizontalFlip(.5),
                    transforms.RandomVerticalFlip(.5),
                    transforms.ToTensor(),
                    transforms.Normalize([0.6716241, 0.48636872, 0.60884315],
                                        [0.27210504, 0.31001145, 0.2918652])])
val_transform = transforms.Compose([transforms.ToPILImage(),
                                                 transforms.ToTensor(),
                                                 transforms.Normalize([0.6716241, 0.48636872, 0.60884315],
                                                                      [0.27210504, 0.31001145, 0.2918652])])

In [4]:
# Load datasets
train_dataset = H5Dataset(image_file='../../pcam/training_split.h5', 
                          label_file='../../Labels/Labels/camelyonpatch_level_2_split_train_y.h5', 
                          transform=transform)
val_dataset = H5Dataset(image_file='../../pcam/validation_split.h5', 
                        label_file='../../Labels/Labels/camelyonpatch_level_2_split_valid_y.h5',
                        transform=val_transform)

test_dataset = H5Dataset(image_file='../../pcam/test_split.h5', 
                        label_file='../../Labels/Labels/camelyonpatch_level_2_split_test_y.h5',
                        transform=val_transform)

# Create dataloaders
bs = 128
train_loader = DataLoader(train_dataset, batch_size=bs, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=bs, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, shuffle=False, num_workers=4)

In [5]:
import sys

import numpy as np
import torch
from torch.nn import functional as F
from torch.nn.modules.utils import _single, _pair, _triple


# from torch._jit_internal import weak_module, weak_script_method


# @weak_module
class PolarConvNd(torch.nn.modules.conv._ConvNd):
    def __init__(self, in_channels=1, out_channels=1, kernel_size=3, dimensions=2, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'):
        self.init_kernel_size = kernel_size
        assert kernel_size % 2 == 1, 'expected kernel size to be odd, found %d' % kernel_size
        self.init_dimensions = dimensions

        self.base_vectors = torch.from_numpy(self.build_base_vectors()).float().to(device)
        self.true_base_vectors_shape = self.base_vectors.shape
        self.base_vectors = self.base_vectors.view(self.true_base_vectors_shape[0],
                                                   np.prod(self.true_base_vectors_shape[1:]).astype(int))

        inferred_kernel_size = self.true_base_vectors_shape[0]
        _kernel_size = _single(inferred_kernel_size)
        _stride = _single(stride)
        _padding = _single(padding)
        _dilation = _single(dilation)
        super(PolarConvNd, self).__init__(
            in_channels, out_channels, _kernel_size, _stride, _padding, _dilation,
            False, _single(0), groups, bias, padding_mode)

        if dimensions == 2:
            self.reconstructed_stride = _pair(stride)
            self.reconstructed_padding = _pair(padding)
            self.reconstructed_dilation = _pair(dilation)
            self.reconstructed_conv_op = F.conv2d
        elif dimensions == 3:
            self.reconstructed_stride = _triple(stride)
            self.reconstructed_padding = _triple(padding)
            self.reconstructed_dilation = _triple(dilation)
            self.reconstructed_conv_op = F.conv3d
        else:
            raise ValueError('dimension %d not supported' % dimensions)

    def build_base_vectors(self):
        kernel_size = self.init_kernel_size
        middle = kernel_size // 2
        dimensions = self.init_dimensions

        base_vectors = []
        # Burning phase: determine the number of base vectors
        unique_distances = []
        if dimensions == 2:
            for i in range(kernel_size):
                for j in range(kernel_size):
                    i_ = abs(i - middle)
                    j_ = abs(j - middle)
                    unique_distances.append(int(i_ * i_ + j_ * j_))
        elif dimensions == 3:
            for i in range(kernel_size):
                for j in range(kernel_size):
                    for k in range(kernel_size):
                        i_ = abs(i - middle)
                        j_ = abs(j - middle)
                        k_ = abs(k - middle)
                        unique_distances.append(int(i_ * i_ + j_ * j_ + k_ * k_))
        unique_distances, distances_counts = np.unique(unique_distances, return_counts=True)
        unique_distances = np.sort(unique_distances)
        print(*zip(unique_distances, distances_counts), len(unique_distances))

        for unique_distance, n in zip(unique_distances, distances_counts):  # number of base vectors
            base_vector = np.zeros([kernel_size] * dimensions)
            if dimensions == 2:
                for i in range(kernel_size):
                    for j in range(kernel_size):
                        i_ = abs(i - middle)
                        j_ = abs(j - middle)
                        if int(i_ * i_ + j_ * j_) == unique_distance:
                            base_vector[i, j] = 1./n
            elif dimensions == 3:
                for i in range(kernel_size):
                    for j in range(kernel_size):
                        for k in range(kernel_size):
                            i_ = abs(i - middle)
                            j_ = abs(j - middle)
                            k_ = abs(k - middle)
                            if int(i_ * i_ + j_ * j_ + k_ * k_) == unique_distance:
                                base_vector[i, j, k] = 1./n
            base_vectors.append(base_vector)
        base_vectors = np.asarray(base_vectors)
        return base_vectors

    # @weak_script_method
    def forward(self, input):
        weight_size = self.weight.shape
        weight = torch.mm(self.weight.view(np.prod(weight_size[:-1]), weight_size[-1]), self.base_vectors) \
            .view(*weight_size[:-1], *self.true_base_vectors_shape[1:])
        return self.reconstructed_conv_op(input, weight, self.bias, self.reconstructed_stride,
                                          self.reconstructed_padding, self.reconstructed_dilation, self.groups)


    def __repr__(self):
        return ('PolarConv%dd' % self.init_dimensions) + '(' + self.extra_repr() + ')'

In [6]:
import torch.nn as nn
from torch.hub import load_state_dict_from_url

__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d']

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
}


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1, conv_type='classical', kernel_size=3):
    """3x3 convolution with padding"""
    if conv_type.lower() == 'classical':
        return nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                         padding=(kernel_size - 1) // 2, groups=groups, bias=False, dilation=dilation)
    elif conv_type.lower() == 'polar':
        return PolarConvNd(in_planes, out_planes, kernel_size=kernel_size, dimensions=2,
                           stride=stride, padding=(kernel_size - 1) // 2, groups=groups, bias=False, dilation=dilation)
    raise ValueError('unknow conv layer type %s' % conv_type)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, conv_type='classical', kernel_size=3):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride, conv_type=conv_type, kernel_size=kernel_size)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes, conv_type=conv_type, kernel_size=kernel_size)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, conv_type='classical', kernel_size=3):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation, conv_type=conv_type, kernel_size=kernel_size)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, conv_type, kernel_size, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(conv_type, kernel_size, block, 64, layers[0])
        self.layer2 = self._make_layer(conv_type, kernel_size, block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(conv_type, kernel_size, block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(conv_type, kernel_size, block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, PolarConvNd):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, conv_type, kernel_size, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer,
                            conv_type=conv_type, kernel_size=kernel_size))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer,
                                conv_type=conv_type, kernel_size=kernel_size))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)

        return x


def _resnet(arch, conv_type, kernel_size, block, layers, pretrained, progress, **kwargs):
    model = ResNet(conv_type, kernel_size, block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model


def resnet18(conv_type, kernel_size, pretrained=False, progress=True, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet18', conv_type, kernel_size, BasicBlock, [2, 2, 2, 2], pretrained, progress,
                   **kwargs)


def resnet34(conv_type, kernel_size, pretrained=False, progress=True, **kwargs):
    """Constructs a ResNet-34 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet34', conv_type, kernel_size, BasicBlock, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet50(conv_type, kernel_size, pretrained=False, progress=True, **kwargs):
    """Constructs a ResNet-50 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet50', conv_type, kernel_size, Bottleneck, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet101(conv_type, kernel_size, pretrained=False, progress=True, **kwargs):
    """Constructs a ResNet-101 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet101', conv_type, kernel_size, Bottleneck, [3, 4, 23, 3], pretrained, progress,
                   **kwargs)


def resnet152(conv_type, kernel_size, pretrained=False, progress=True, **kwargs):
    """Constructs a ResNet-152 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet152', conv_type, kernel_size, Bottleneck, [3, 8, 36, 3], pretrained, progress,
                   **kwargs)


def resnext50_32x4d(conv_type, kernel_size, pretrained=False, progress=True, **kwargs):
    """Constructs a ResNeXt-50 32x4d model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 4
    return _resnet('resnext50_32x4d', conv_type, kernel_size, Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)


def resnext101_32x8d(conv_type, kernel_size, pretrained=False, progress=True, **kwargs):
    """Constructs a ResNeXt-101 32x8d model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 8
    return _resnet('resnext101_32x8d', conv_type, kernel_size, Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)

In [7]:
# Initialize model, loss function, and optimizer
device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')
polar = True
model = resnet50(conv_type='classical' if not polar else 'polar',
                                                         kernel_size=3,
                                                         num_classes=1)
model = model.to(device)

# criterion = nn.CrossEntropyLoss()
criterion = torch.nn.BCEWithLogitsLoss()
# optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
optimizer = optim.Adam(model.parameters(), lr=1e-2, weight_decay=1e-6)

(np.int64(0), np.int64(1)) (np.int64(1), np.int64(4)) (np.int64(2), np.int64(4)) 3
(np.int64(0), np.int64(1)) (np.int64(1), np.int64(4)) (np.int64(2), np.int64(4)) 3
(np.int64(0), np.int64(1)) (np.int64(1), np.int64(4)) (np.int64(2), np.int64(4)) 3
(np.int64(0), np.int64(1)) (np.int64(1), np.int64(4)) (np.int64(2), np.int64(4)) 3
(np.int64(0), np.int64(1)) (np.int64(1), np.int64(4)) (np.int64(2), np.int64(4)) 3
(np.int64(0), np.int64(1)) (np.int64(1), np.int64(4)) (np.int64(2), np.int64(4)) 3
(np.int64(0), np.int64(1)) (np.int64(1), np.int64(4)) (np.int64(2), np.int64(4)) 3
(np.int64(0), np.int64(1)) (np.int64(1), np.int64(4)) (np.int64(2), np.int64(4)) 3
(np.int64(0), np.int64(1)) (np.int64(1), np.int64(4)) (np.int64(2), np.int64(4)) 3
(np.int64(0), np.int64(1)) (np.int64(1), np.int64(4)) (np.int64(2), np.int64(4)) 3
(np.int64(0), np.int64(1)) (np.int64(1), np.int64(4)) (np.int64(2), np.int64(4)) 3
(np.int64(0), np.int64(1)) (np.int64(1), np.int64(4)) (np.int64(2), np.int64(4)) 3
(np.

In [8]:
# Training and validation loops
def train_and_validate(model, train_loader, val_loader, criterion, optimizer, epochs=10):
    for epoch in range(epochs):
        # Training phase
        model.train()
        train_loss, train_correct, train_total = 0, 0, 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            # Metrics
            train_loss += loss.item()
            train_correct += ((outputs > .5).float() == labels).sum().item()
            train_total += labels.size(0)
            # train_correct += predicted.eq(labels).sum().item()
        
        # Validation phase
        model.eval()
        val_loss, val_correct, val_total = 0, 0, 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                
                # Forward pass
                outputs = model(images)
                loss = criterion(outputs, labels)
                
                # Metrics
                val_loss += loss.item()
                val_correct += ((outputs > .5).float() == labels).sum().item()
                val_total += labels.size(0)
                # val_correct += predicted.eq(labels).sum().item()
                
        # Test phase
        model.eval()
        test_loss, test_correct, test_total = 0, 0, 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                
                # Forward pass
                outputs = model(images)
                loss = criterion(outputs, labels)
                
                # Metrics
                test_loss += loss.item()
                test_correct += ((outputs > .5).float() == labels).sum().item()
                test_total += labels.size(0)
                # test_correct += predicted.eq(labels).sum().item()
        
        # Print epoch results
        print(f"Epoch {epoch+1}/{epochs}")
        print(f"Train Loss: {train_loss/len(train_loader):.4f}, Train Acc: {100 * train_correct/train_total:.2f}%")
        print(f"Val Loss: {val_loss/len(val_loader):.4f}, Val Acc: {100 * val_correct/val_total:.2f}%")
        print(f"Test Loss: {test_loss/len(test_loader):.4f}, Test Acc: {100 * test_correct/test_total:.2f}%\n\n")

# Train and validate the model
train_and_validate(model, train_loader, val_loader, criterion, optimizer, epochs=10)

Epoch 1/10
Train Loss: 0.5225, Train Acc: 73.47%
Val Loss: 0.5100, Val Acc: 78.31%
Test Loss: 0.4795, Test Acc: 79.83%


Epoch 2/10
Train Loss: 0.4545, Train Acc: 77.82%
Val Loss: 0.4251, Val Acc: 76.21%
Test Loss: 0.4416, Test Acc: 75.32%


Epoch 3/10
Train Loss: 0.4166, Train Acc: 79.99%
Val Loss: 0.4203, Val Acc: 77.57%
Test Loss: 0.4283, Test Acc: 76.58%


Epoch 4/10
Train Loss: 0.3806, Train Acc: 82.22%
Val Loss: 0.4063, Val Acc: 81.64%
Test Loss: 0.4090, Test Acc: 81.49%


Epoch 5/10
Train Loss: 0.3452, Train Acc: 84.38%
Val Loss: 0.3623, Val Acc: 81.41%
Test Loss: 0.3793, Test Acc: 81.48%


Epoch 6/10
Train Loss: 0.3197, Train Acc: 85.70%
Val Loss: 0.4405, Val Acc: 83.51%
Test Loss: 0.4427, Test Acc: 83.87%


Epoch 7/10
Train Loss: 0.3016, Train Acc: 86.72%
Val Loss: 0.3733, Val Acc: 81.49%
Test Loss: 0.4092, Test Acc: 79.17%


Epoch 8/10
Train Loss: 0.2908, Train Acc: 87.32%
Val Loss: 0.3758, Val Acc: 82.79%
Test Loss: 0.3810, Test Acc: 80.45%


Epoch 9/10
Train Loss: 0.2813, T

In [9]:
# le = 1e-3, wd = 1e-6

# Epoch 1/10
# Train Loss: 0.6988, Train Acc: 50.10%
# Val Loss: 0.6933, Val Acc: 50.05%
# Test Loss: 0.6933, Test Acc: 50.02%


# Epoch 2/10
# Train Loss: 0.6932, Train Acc: 50.00%
# Val Loss: 0.6932, Val Acc: 50.05%
# Test Loss: 0.6932, Test Acc: 50.02%


# Epoch 3/10
# Train Loss: 0.6932, Train Acc: 50.00%
# Val Loss: 0.6933, Val Acc: 50.05%
# Test Loss: 0.6933, Test Acc: 50.02%


# Epoch 4/10
# Train Loss: 0.6932, Train Acc: 50.00%
# Val Loss: 0.6933, Val Acc: 50.05%
# Test Loss: 0.6934, Test Acc: 50.02%


# Epoch 5/10
# Train Loss: 0.6932, Train Acc: 50.00%
# Val Loss: 0.6932, Val Acc: 50.05%
# Test Loss: 0.6931, Test Acc: 50.02%


# Epoch 6/10
# Train Loss: 0.6932, Train Acc: 50.00%
# Val Loss: 0.6932, Val Acc: 50.05%
Test Loss: 0.6932, Test Acc: 50.02%

SyntaxError: invalid syntax (1124503448.py, line 36)