<a href="https://colab.research.google.com/github/annaGrd/Internship/blob/main/Main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import

In [None]:
!pip install wandb -qU
!pip install einops
!pip install pytorch_lightning



In [None]:
import os
import torch
import torch.nn as nn
import pytorch_lightning as pl
import wandb
import cv2
import numpy as np
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from PIL import Image
from einops import rearrange
from matplotlib.ticker import MaxNLocator
import argparse
import matplotlib.pyplot as plt
import matplotlib as mpl
import scipy.stats
from skimage.util import random_noise
from google.colab import drive
# Mount Google Drive and switch to the project folder (presumably so that the
# local `opts`, `utils` and `models` packages imported below can be found).
drive.mount('/content/drive/', force_remount=True)
chemin = '/content/drive/My Drive/Stage/StageAnna/'
# makedirs also creates a missing parent folder and tolerates an existing one.
os.makedirs(chemin, exist_ok=True)
os.chdir(chemin)

Mounted at /content/drive/


# Parser

In [None]:
from opts import parser
from utils import dataloaders
# args, unknown = parser.parse_known_args()

def _str2bool(value):
    """Convert a command-line token such as "true"/"False"/"0" into a bool.

    `type=bool` must not be used with argparse: bool("False") is True because
    any non-empty string is truthy.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if text in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def arguments(argv=()):
    """Build the experiment options.

    Args:
        argv: optional sequence of command-line tokens to parse. The default
            (empty) yields the default value of every option, which is what a
            notebook needs since the kernel's own sys.argv is unrelated.

    Returns:
        argparse.Namespace with the attributes `filename`, `lr`, `batch_size`,
        `epochs`, `num_classes`, `noise_type`, `load`, `save` and `model`.
        Note that the learning rate is stored as `lr`: argparse derives the
        attribute name from the first long option (`--lr`), so
        `--learning-rate` is only an alias on the command line.
    """
    parser = argparse.ArgumentParser(description="Classifying natural images with complex-valued neural networks")

    parser.add_argument('--filename', type=str, default="test")
    parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float)
    parser.add_argument('--batch-size', default=256, type=int)
    parser.add_argument('--epochs', default=1, type=int)
    parser.add_argument('--num_classes', default=10, type=int)
    parser.add_argument('--noise_type', default=None)
    parser.add_argument('--load', default=False, type=_str2bool)
    parser.add_argument('--save', default=False, type=_str2bool)

    parser.add_argument('--model', type=str, default='AlexNet_complex')

    return parser.parse_args(list(argv))

args = arguments()

# Root folder for the generated figures, with one sub-folder per model.
chemin = "/content/drive/My Drive/Stage/StageAnna/Image/test_VGG16"
# makedirs creates every missing level at once and is a no-op when the folder
# already exists, so the cell can be re-run safely.
os.makedirs(os.path.join(chemin, args.model), exist_ok=True)

# Device configuration, hyper-parameters and class names



In [None]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Authenticate against Weights & Biases before any run is created.
wandb.login()

# Hyper-parameters taken from the parsed options.
batch_size = args.batch_size
learning_rate = args.lr

# Human-readable class names, indexed by integer label.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

[34m[1mwandb[0m: Currently logged in as: [33manna-grd[0m ([33mcomplex-dnns[0m). Use [1m`wandb login --relogin`[0m to force relogin


# Choice of architecture and loaders

In [None]:
# Checkpoint folder read by main() when loading/saving weights. It is defined
# unconditionally so that every model choice can use load/save (it used to
# exist only for the two AlexNet variants, raising NameError for the others).
loader_path = "/content/drive/My Drive/Stage/StageAnna/Image/_trained_model"

_COMPLEX_VGG = ('VGG11_complex', 'VGG13_complex', 'VGG16_complex', 'VGG19_complex')
_REAL_VGG = ('VGG11_real', 'VGG13_real', 'VGG16_real', 'VGG19_real')
# Models that are fed by the iget_* loaders; everything else uses the RGB ones.
_IGET_MODELS = ('AlexNet_complex_bio', 'AlexNet_complex') + _COMPLEX_VGG

# Import the architecture(s) matching the requested model name.
if args.model == 'AlexNet_real_small':
    from models.AlexNet_real_small import AlexNet
elif args.model == 'AlexNet_complex_bio':
    from models.AlexNet_complex_bio import AlexNet
elif args.model == 'AlexNet_complex':
    from models.AlexNet_complex import ComplexWeigth_AlexNet, AlexNet
elif args.model in _COMPLEX_VGG:
    from models.VGG_complex import VGG11, VGG13, VGG16, VGG19
elif args.model in _REAL_VGG:
    from models.VGG_real import VGG11, VGG13, VGG16, VGG19
else:  # 'AlexNet_real', which is also the fallback for unrecognised names
    from models.AlexNet_real import AlexNet

# Build the datasets once, after the import, exactly as each branch did before.
if args.model in _IGET_MODELS:
    train_loader, val_loader = dataloaders.iget_train_data()
    test_loader = dataloaders.iget_test_data()
else:
    train_loader, val_loader = dataloaders.RGBtrain_data()
    test_loader = dataloaders.RGBtest_data()

Files already downloaded and verified
Files already downloaded and verified


# Plot d'image

In [None]:
def plot_mean_image(c1, mx1, c2, mx2, c3, c4, c5, mx5, idx):
    """Save a 2x4 grid of channel-averaged phase maps for one sample.

    For each of the eight activations the circular mean of the phase over the
    first axis of the selected sample is drawn with the 'hsv' colormap, and the
    mean magnitude (scaled to a maximum of 1) is used as per-pixel opacity.

    Args:
        c1 ... mx5: activation tensors indexed as [sample, channel, row, col];
            the last two dimensions are assumed to be square.
        idx: index of the sample to visualise.

    Side effects:
        Writes `<chemin>/<args.model>/Spatial/mean.png` (the folder is created
        if needed) and closes the figure.
    """
    layers = [(c1, "Convolution 1"), (mx1, "MaxPool 1"), (c2, "Convolution 2"),
        (mx2, "MaxPool 2"),(c3, "Convolution 3"), (c4, "Convolution 4"),
        (c5, "Convolution 5"), (mx5, "MaxPool 5")]

    fig, axs = plt.subplots(2, 4, squeeze=False, layout='tight', figsize=[80, 40])
    cmap = plt.colormaps['hsv']

    for ax, (layer, name) in zip(axs.flat, layers):   # one panel per operation
        edge = layer.shape[-1]
        sample = layer[idx, :, :, :]

        phase = sample.angle().detach().cpu().numpy()
        mean = scipy.stats.circmean(phase, axis=0, high=np.pi, low=-np.pi)

        magnitude = sample.abs().mean(0, True).detach().cpu().numpy()
        # Scale the opacity to [0, 1]; an all-zero map is left untouched.
        if magnitude.max():
            magnitude = magnitude / magnitude.max()

        y, x = np.mgrid[slice(edge), slice(edge)]

        ax.set(aspect='equal', adjustable='box')
        ax.pcolormesh(x, y, mean, cmap=cmap, vmin=-np.pi, vmax=np.pi, alpha=magnitude)
        ax.title.set_text(name)

    # The Spatial sub-folder is not created anywhere else, so make sure it exists.
    out_dir = chemin+f"/{args.model}/Spatial"
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(out_dir+"/mean.png")
    # Close (not just clear) the figure so the large canvas is released.
    plt.close(fig)

In [None]:
def plot_image(c1, mx1, c2, mx2, c3, c4, c5, mx5, idx):
    """Save one phase map per channel of the first convolution for one sample.

    Each channel is drawn in its own panel: colour encodes the phase ('hsv',
    normalised to [-pi, pi]) and opacity encodes the magnitude scaled to a
    maximum of 1. Only `c1` is currently plotted; the other activations are
    accepted so that the call signature matches the other plotting helpers.

    Args:
        c1 ... mx5: activation tensors indexed as [sample, channel, row, col].
        idx: index of the sample to visualise.

    Side effects:
        Writes `<chemin>/<args.model>/<operation> spatial.png` and closes the
        figure.
    """
    layers = [(c1, "Convolution 1")]

    rows = 8
    cmap = plt.colormaps['hsv']
    norm = mpl.colors.Normalize(-np.pi, np.pi)

    for (layer, name) in layers:   # each operation on a single image

        features = layer.shape[1]
        # Ceiling division: enough columns for every channel, even when the
        # channel count is below 8 or not a multiple of 8.
        columns = max(1, -(-features // rows))
        fig, axs = plt.subplots(rows, columns, sharex=True, sharey=True, squeeze=False, layout='tight', figsize=[5*columns,5*rows])
        panels = list(axs.flat)

        for feature in range(features):
            fmap = layer[idx, feature, :, :]
            edge = fmap.shape[-1]

            phase = fmap.angle().detach().cpu().numpy()
            magnitude = fmap.abs().detach().cpu().numpy()
            if magnitude.max():
                magnitude = magnitude / magnitude.max()

            y, x = np.mgrid[slice(edge), slice(edge)]

            ax = panels[feature]
            ax.set(aspect='equal', adjustable='box')
            ax.pcolormesh(x, y, phase, cmap=cmap, norm=norm, alpha=magnitude)

        # Hide the panels left over when the grid has more cells than channels.
        for ax in panels[features:]:
            ax.set_visible(False)

        # Figure-level title (plt.title would only label the last panel).
        fig.suptitle(f"Operation : {name}")
        fig.savefig(chemin+f"/{args.model}/{name} spatial.png")
        # Close (not just clear) the figure so its memory is released.
        plt.close(fig)

# Plot de phase

In [None]:
def polar_mean_plot(c1, mx1, c2, mx2, c3, c4, c5, mx5, avg, c6, c7, c8, idx):
    """Save a 3x4 grid of polar scatter plots of channel-averaged activations.

    For each activation, every spatial position of the selected sample becomes
    one point whose angle is the circular mean of the phase over the first axis
    and whose radius is the mean magnitude over that axis. Points are coloured
    by phase ('hsv', normalised to [-pi, pi]).

    Args:
        c1 ... c8: activation tensors indexed as [sample, channel, row, col].
        idx: index of the sample to visualise.

    Side effects:
        Writes `<chemin>/<args.model>/Polar/mean.png` (the folder is created if
        needed) and closes the figure.
    """
    layers = [(c1, "Convolution 1"), (mx1, "MaxPool 1"), (c2, "Convolution 2"),
        (mx2, "MaxPool 2"),(c3, "Convolution 3"), (c4, "Convolution 4"),
        (c5, "Convolution 5"), (mx5, "MaxPool 5"), (avg, "AvgPool"),
        (c6, "Convolution 6"), (c7, "Convolution 7") , (c8, "Convolution 8")]

    fig, axs = plt.subplots(3, 4, figsize=[80, 60], subplot_kw={'projection': 'polar'}, squeeze=False, layout='tight')
    cmap = plt.colormaps['hsv']
    norm = mpl.colors.Normalize(-np.pi, np.pi)

    for ax, (layer, name) in zip(axs.flat, layers):   # one panel per operation
        sample = layer[idx, :, :, :]

        phase = sample.angle().detach().cpu().numpy()
        phase = scipy.stats.circmean(phase, axis=0, high=np.pi, low=-np.pi)

        magnitude = sample.abs().mean(0, True).detach().cpu().numpy()

        ax.set_ylim([0, magnitude.max()])
        ax.set(aspect='equal', adjustable='box')
        ax.scatter(phase.flatten(), magnitude.flatten(), c=phase.flatten(), cmap=cmap, norm=norm)
        ax.title.set_text(name)

    # The Polar sub-folder is not created anywhere else, so make sure it exists.
    out_dir = chemin+f"/{args.model}/Polar"
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(out_dir+"/mean.png")
    # Close (not just clear) the figure so the large canvas is released.
    plt.close(fig)

In [None]:
def polar_plot(c1, mx1, c2, mx2, c3, c4, c5, mx5, avg, c6, c7, c8, idx):
    """Save one polar scatter plot per channel of the first convolution.

    Every spatial position of a channel becomes one point (angle = phase,
    radius = magnitude) coloured by phase ('hsv', normalised to [-pi, pi]).
    Only `c1` is currently plotted; the other activations are accepted so that
    the call signature matches `polar_mean_plot`.

    Args:
        c1 ... c8: activation tensors indexed as [sample, channel, row, col].
        idx: index of the sample to visualise.

    Side effects:
        Writes `<chemin>/<args.model>/<operation> polar.png` and closes the
        figure.
    """
    layers = [(c1, "Convolution 1")]

    rows = 8
    norm = mpl.colors.Normalize(-np.pi, np.pi)

    for (layer, name) in layers:   # each operation on a single image

        features = layer.shape[1]
        # Ceiling division: enough columns for every channel, even when the
        # channel count is below 8 or not a multiple of 8.
        columns = max(1, -(-features // rows))
        fig, axs = plt.subplots(rows, columns, sharex=True, sharey=True, subplot_kw={'projection': 'polar'}, squeeze=False, layout='tight', figsize=[5*columns,5*rows])
        panels = list(axs.flat)

        for feature in range(features):
            fmap = layer[idx, feature, :, :]
            phase = fmap.angle().detach().cpu().numpy()
            magnitude = fmap.abs().detach().cpu().numpy()

            ax = panels[feature]
            ax.set_ylim([0, magnitude.max()])
            ax.set(aspect='equal', adjustable='box')
            ax.scatter(phase.flatten(), magnitude.flatten(), c=phase.flatten(), s=100, norm = norm, cmap='hsv')

        # Hide the panels left over when the grid has more cells than channels.
        for ax in panels[features:]:
            ax.set_visible(False)

        # Figure-level title (plt.title would only label the last panel).
        fig.suptitle(f"Operation : {name}")
        fig.savefig(chemin+f"/{args.model}/{name} polar.png")
        # Close (not just clear) the figure so its memory is released.
        plt.close(fig)

# Noise functions

In [None]:
def add_fog(images, intensity=0.5):
    """Blend a diagonal fog gradient into an image or a batch of images.

    The fog layer is `x * y * intensity`, with x and y running from 0 to 1
    across the width and height, and it is mixed with the input using the
    usual blend formula `image * (1 - intensity) + fog * intensity`.

    The previous implementation went through PIL, which cannot be built from
    the float (C, H, W) tensors the callers pass, and it truncated the fog
    layer to uint8 (all zeros). The blend is now done directly on tensors.

    Args:
        images: tensor (or array-like) that is either a batch laid out as
            (B, C, H, W) or a single image laid out as (H, W, C).
        intensity: blend factor; 0 returns the input values unchanged.
            NOTE(review): the fog layer spans [0, intensity], which presumes
            pixel values on a 0-1 scale -- confirm against the data loaders.

    Returns:
        A new float32 tensor with the same layout as the input; the input is
        not modified.
    """
    batch = torch.as_tensor(images, dtype=torch.float32)

    single_image = batch.dim() == 3
    if single_image:
        # (H, W, C) -> (1, C, H, W) so both layouts share the same code path.
        batch = batch.permute(2, 0, 1).unsqueeze(0)

    height, width = batch.shape[-2:]
    ramp_y = torch.linspace(0, 1, height).unsqueeze(1)   # (H, 1)
    ramp_x = torch.linspace(0, 1, width).unsqueeze(0)    # (1, W)
    fog = ramp_x * ramp_y * intensity                    # (H, W), broadcast below

    foggy = batch * (1.0 - intensity) + fog * intensity

    if single_image:
        # Back to the caller's (H, W, C) layout.
        foggy = foggy.squeeze(0).permute(1, 2, 0)

    return foggy

In [None]:
def add_snow(image, snow_intensity=0.1):
    """Overlay random white specks on a PIL image.

    Args:
        image: PIL image; it is blended with an 'RGB' overlay, so it is
            expected to be in 'RGB' mode itself.
        snow_intensity: used both as the per-channel probability that an
            overlay value is 255 and as the blend weight of the overlay.

    Returns:
        The blended PIL image.
    """
    width, height = image.size
    # Boolean mask: True where a uniform draw falls below the intensity.
    flakes = np.random.rand(height, width, 3) < snow_intensity
    overlay = Image.fromarray((flakes * 255).astype(np.uint8), mode='RGB')
    return Image.blend(image, overlay, snow_intensity)

In [None]:
def add_motion_blur(image, kernel_size=15, angle=45):
    """Blur an image along a straight line to imitate motion.

    Args:
        image: image array accepted by cv2.filter2D.
        kernel_size: side length of the square convolution kernel.
        angle: rotation of the blur direction in degrees.

    Returns:
        The filtered image, with the same depth as the input.
    """
    # Start from a horizontal line of ones through the middle row.
    line = np.zeros((kernel_size, kernel_size))
    line[int((kernel_size-1)/2), :] = np.ones(kernel_size)

    # Rotate the line around the kernel centre, then divide by the line length.
    centre = kernel_size/2-0.5
    rotation = cv2.getRotationMatrix2D((centre, centre), angle, 1.0)
    kernel = cv2.warpAffine(line, rotation, (kernel_size, kernel_size))
    kernel = kernel / kernel_size

    # Convolve the image with the kernel.
    return cv2.filter2D(image, -1, kernel)

In [None]:
def add_glass_blur(image, kernel_size=3, iterations=5):
    """Imitate frosted glass by repeatedly swapping neighbouring pixels.

    On every pass each pixel is exchanged with a randomly chosen pixel at most
    `kernel_size // 2` rows/columns away (clamped to the image border). The
    swaps are applied to the working copy, so successive passes compound; the
    previous version always read from the untouched input, which made
    `iterations` ineffective and duplicated pixels instead of swapping them.

    Args:
        image: numpy array whose first two axes are rows and columns.
        kernel_size: size of the neighbourhood a pixel may be swapped within.
        iterations: number of passes over the image.

    Returns:
        A new array of the same shape and dtype; the input is not modified.
    """
    max_shift = kernel_size // 2
    rows, cols = image.shape[:2]

    blurred = image.copy()
    for _ in range(iterations):
        for i in range(rows):
            for j in range(cols):
                # Pick a random neighbour and keep it inside the image.
                ni = i + np.random.randint(-max_shift, max_shift + 1)
                nj = j + np.random.randint(-max_shift, max_shift + 1)
                ni = min(max(ni, 0), rows - 1)
                nj = min(max(nj, 0), cols - 1)
                # Swap through an explicit copy: for multi-channel images
                # blurred[i, j] is a view and would be overwritten otherwise.
                displaced = blurred[ni, nj].copy()
                blurred[ni, nj] = blurred[i, j]
                blurred[i, j] = displaced
    return blurred

In [None]:
def add_zoom_blur(image, zoom_factor=1.2, iterations=10):
    """Average progressively enlarged, centre-cropped copies of an image.

    Copy number i is scaled by `zoom_factor ** i`, cropped back to the original
    size around its centre and folded into a running mean, so all copies end up
    with equal weight.

    Args:
        image: numpy image, either (H, W) or (H, W, C), on a 0-255 scale (the
            result is clipped to that range).
        zoom_factor: per-step enlargement. NOTE(review): the centre crop only
            works for factors >= 1, since smaller copies cannot be cropped to
            the original size.
        iterations: number of scaled copies to average.

    Returns:
        uint8 array with the same shape as the input.
    """
    h, w = image.shape[:2]

    # Running mean kept in float32 to avoid uint8 rounding at every step.
    zoom_blur_image = np.zeros(image.shape, dtype=np.float32)

    for i in range(iterations):
        scale = zoom_factor ** i
        scaled_image = cv2.resize(image, (0, 0), fx=scale, fy=scale)

        # Crop the enlarged copy back to the original size around its centre.
        scaled_h, scaled_w = scaled_image.shape[:2]
        top = (scaled_h - h) // 2
        left = (scaled_w - w) // 2
        crop = scaled_image[top:top+h, left:left+w].astype(np.float32)
        # cv2.resize drops a trailing single-channel axis; restore it.
        crop = crop.reshape(image.shape)

        # Blend in float32. cv2.addWeighted refuses to mix a float32
        # accumulator with a uint8 frame unless an output dtype is given, so
        # the weighted sum is done directly with numpy.
        alpha = 1.0 / (i + 1)
        zoom_blur_image = zoom_blur_image * (1.0 - alpha) + crop * alpha

    # Clip the values to the valid range and convert back to uint8.
    return np.clip(zoom_blur_image, 0, 255).astype(np.uint8)

In [None]:
def add_defocus_noise(image, sigma=1.0):
    """
    Blur an image tensor with a Gaussian kernel and add unit-variance noise.

    Args:
        image (torch.Tensor): Input image tensor of shape (C, H, W).
        sigma (float): Standard deviation of the Gaussian blur kernel; the
            kernel size is int(4 * sigma) forced to be odd.

    Returns:
        torch.Tensor: Float tensor of shape (C, H, W) on the CPU. Values are
        clipped to [0, 255] and truncated to whole numbers on the way.
    """
    # OpenCV expects channels last: (C, H, W) -> (H, W, C).
    hwc = image.permute(1, 2, 0).cpu().numpy()

    # Defocus step: Gaussian blur with an odd kernel size.
    ksize = int(sigma * 4) | 1
    blurred = cv2.GaussianBlur(hwc, (ksize, ksize), sigma)

    # Additive Gaussian noise with zero mean and unit standard deviation.
    noisy = blurred + np.random.normal(0, 1, hwc.shape)

    # Quantise to uint8, then return to a channels-first float tensor.
    quantised = np.clip(noisy, 0, 255).astype(np.uint8)
    return torch.from_numpy(quantised).permute(2, 0, 1).float()

# Example usage:
# Assuming you have an image tensor `input_image` of shape (C, H, W)
# input_image = torch.randn(3, 256, 256)  # Example random input image
# noisy_image = add_defocus_noise(input_image, sigma=1.0)

# Note: Ensure you have OpenCV installed (`pip install opencv-python`) for GaussianBlur function.


# RGB image

In [None]:
def rgb_plot(rgb_loader, index, noise_level, noise_type):
    """Save one image of the first batch, with the requested noise applied.

    Args:
        rgb_loader: iterable yielding (images, labels) batches; images are
            indexed as [sample, channel, row, col].
        index: position of the image inside the first batch.
        noise_level: strength forwarded to the selected noise function.
        noise_type: one of 'gaussian', 's&p', 'speckle', 'localvar', 'fog',
            'poisson'; any other value leaves the image unchanged.

    Side effects:
        Writes `<chemin>/<args.model>/<index>_<noise_type>_<noise_level>.png`
        and closes the figure.
    """
    train_features, train_labels = next(iter(rgb_loader))
    # Channels last, as expected by imshow: (C, H, W) -> (H, W, C).
    image = rearrange(train_features[index], 'b c h -> c h b ')

    if noise_type == 'gaussian':
        image = torch.tensor(random_noise(image.cpu(), mode='gaussian', mean=noise_level))
    elif noise_type == 's&p':
        image = torch.tensor(random_noise(image.cpu(), mode='s&p', amount=noise_level))
    elif noise_type == 'speckle':
        image = torch.tensor(random_noise(image.cpu(), mode='speckle', mean=noise_level),
                              dtype=torch.float).to(device)
    elif noise_type == 'localvar':
        noise_tensor = torch.ones(image.shape)*noise_level
        image = torch.tensor(random_noise(image.cpu(), mode='localvar', local_vars=noise_tensor))
    elif noise_type == 'fog':
        image = torch.tensor(add_fog(image.cpu(), intensity=noise_level),
                                dtype=torch.float).to(device)
    elif noise_type == 'poisson':
        image = torch.tensor(poisson_noise(image.cpu(), lmba=noise_level, clip=True))

    label = train_labels[index]

    fig, ax = plt.subplots()
    # Actually draw the image (previously only the title was saved). Some
    # branches leave the tensor on the GPU, so bring it back to the CPU first.
    ax.imshow(torch.as_tensor(image).detach().cpu().numpy())
    ax.set_title(f"Label: {classes[int(label)]}, noise {noise_type}, level {noise_level}")
    fig.savefig(chemin+f"/{args.model}/{index}_{noise_type}_{noise_level}.png")
    # Close (not just clear) the figure so its memory is released.
    plt.close(fig)

# Training

In [None]:
def training(model, num_epochs, epoch, train_loader, optimizer, criterion):
    """Train the model for one pass over `train_loader`.

    The loss and the predictions are computed on the magnitude (`abs()`) of the
    model output.

    Args:
        model: network to optimise; put in training mode by this function.
        num_epochs: unused, kept for call compatibility.
        epoch: index of the current epoch, only used for the progress message.
        train_loader: iterable of (images, labels) batches.
        optimizer: optimiser stepping the model parameters.
        criterion: loss taking (output magnitudes, labels).

    Returns:
        (mean loss per batch, fraction of correctly classified samples).
    """
    model.train()

    run_loss = 0.0
    cnt = 0
    total = 0
    correct = 0

    for images, labels in train_loader:   # one optimisation step per batch

        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = model(images)
        outputs_magnitude = outputs.abs()
        loss = criterion(outputs_magnitude, labels)

        # Running statistics; detach so the bookkeeping stays out of autograd.
        run_loss += loss.item()
        total += labels.size(0)
        cnt += 1
        _, predicted = torch.max(outputs_magnitude.detach(), 1)
        correct += (predicted == labels).sum().item()

        # Backward and optimize
        optimizer.zero_grad()   # clear gradients left from the previous step
        loss.backward()
        optimizer.step()

    print()
    print(f'Epoch {epoch}')
    return run_loss / cnt, correct / total

# Validation

In [None]:
def validation(model, val_loader, criterion):
    """Evaluate the model on `val_loader` without updating it.

    The loss and the predictions are computed on the magnitude (`abs()`) of the
    model output. The overall accuracy is printed as a percentage.

    Args:
        model: network to evaluate; put in evaluation mode by this function.
        val_loader: iterable of (images, labels) batches.
        criterion: loss taking (output magnitudes, labels).

    Returns:
        (mean loss per batch as a float, fraction of correctly classified
        samples).
    """
    model.eval()

    correct, total, cnt = 0, 0, 0
    run_loss = 0.0

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            outputs_magnitude = outputs.abs()
            loss = criterion(outputs_magnitude, labels)

            # .item() keeps the running loss a plain float, as in training().
            run_loss += loss.item()
            _, predicted = torch.max(outputs_magnitude, 1)
            total += labels.size(0)
            cnt += 1
            correct += (predicted == labels).sum().item()

    acc = 100. * correct/total
    print(f'Accuracy of the network: {acc} %')

    return run_loss / cnt, correct / total

# Poisson noise

In [None]:
def poisson_noise(images, lmba=1, clip=False):
    """Rescale each image by a Poisson-distributed factor.

    For every image a single integer is drawn from Poisson(lmba) and the whole
    image is multiplied by that integer divided by `vals`, the number of unique
    pixel values rounded up to a power of two.

    NOTE(review): the draw has no `size` argument, so all pixels of an image
    share the same factor; if independent per-pixel noise was intended this
    needs to be revisited -- confirm.

    Args:
        images: either a batch with 4 dimensions (iterated over its first axis)
            or a single 3-dimensional image. A 3-dimensional input is
            rearranged with 'b c h -> h b c' and given a leading axis, then
            rearranged back before returning. The in-place `squeeze_` call
            presumes a torch tensor.
        lmba: rate of the Poisson distribution.
        clip: when True, clip the result to [low_clip, 1.0], where low_clip is
            -1.0 if the image contains negative values and 0.0 otherwise.

    Returns:
        `images` after processing. For a 4-dimensional input the results are
        written back into the object that was passed in.
    """

    # Remember whether a single image was passed so the layout can be restored.
    shape = False
    if len(images.shape) == 3 :
        images = rearrange(images, 'b c h -> h b c')
        images = images[None,...]
        shape = True

    # Unseeded generator: results differ from run to run.
    rng = np.random.default_rng(None)

    for i in range(images.shape[0]):
        image = images[i]

        # Determine unique values in image & calculate the next power of two
        vals = len(np.unique(image))
        vals = 2 ** np.ceil(np.log2(vals))

        # Detect if a signed image was input
        if image.min() < 0:
            low_clip = -1.0
        else:
            low_clip = 0.0

        # Ensure image is exclusively positive
        if low_clip == -1.0:
            old_max = image.max()
            image = (image + 1.0) / (old_max + 1.0)

        # Generating noise for each unique value in image.
        # NOTE(review): rng.poisson(lmba) returns one scalar per image here.
        out = image * rng.poisson(lmba) / float(vals)

        # Return image to original range if input was signed
        if low_clip == -1.0:
            out = out * (old_max + 1.0) - 1.0

        if clip:
            out = np.clip(out, low_clip, 1.0)

        #images[i] = torch.from_numpy(out)
        # In-place write into the (possibly caller-owned) batch.
        images[i] = out

    if shape :
        # Drop the leading axis added above and undo the first rearrange.
        images.squeeze_(0)
        images = rearrange(images, 'b c h -> c h b')

    return images

# Testing

In [None]:
def _apply_test_noise(images, noise_type, noise_level):
    """Return `images` corrupted with the requested noise, on `device`.

    Unknown noise types (including None) return the batch unchanged.
    """
    cpu_images = images.cpu()
    if noise_type == 'gaussian':
        noisy = random_noise(cpu_images, mode='gaussian', mean=noise_level, clip=True)
    elif noise_type == 's&p':
        noisy = random_noise(cpu_images, mode='s&p', amount=noise_level, clip=True)
    elif noise_type == 'localvar':
        noise_tensor = torch.ones(images.shape)*noise_level
        noisy = random_noise(cpu_images, mode='localvar', local_vars=noise_tensor, clip=True)
    elif noise_type == 'speckle':
        noisy = random_noise(cpu_images, mode='speckle', mean=noise_level)
    elif noise_type == 'fog':
        noisy = add_fog(cpu_images, intensity=noise_level)
    elif noise_type == 'poisson':
        noisy = poisson_noise(cpu_images, lmba=noise_level, clip=True)
    else:
        return images
    # as_tensor accepts both numpy arrays and tensors without the copy warning
    # that torch.tensor() emits when it is handed a tensor.
    return torch.as_tensor(noisy, dtype=torch.float).to(device)


def testing(model, test_loader, criterion, noise_type, rgb_loader, noise_level):
    """Evaluate the model on (optionally corrupted) test batches.

    The loss and the predictions are computed on the magnitude (`abs()`) of the
    model output.

    Args:
        model: network to evaluate; put in evaluation mode by this function.
        test_loader: iterable of (images, labels) batches.
        criterion: loss taking (output magnitudes, labels).
        noise_type: name of the corruption to apply ('gaussian', 's&p',
            'localvar', 'speckle', 'fog', 'poisson') or anything else for none.
        rgb_loader: currently unused, kept for call compatibility.
        noise_level: strength forwarded to the noise function.

    Returns:
        (mean loss per batch as a float, fraction of correctly classified
        samples).
    """
    # put the model in evaluation mode
    model.eval()

    correct, total, cnt = 0, 0, 0
    run_loss = 0.0

    with torch.no_grad():   # no backward pass is needed
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            images = _apply_test_noise(images, noise_type, noise_level)

            outputs = model(images)
            outputs_magnitude = outputs.abs()
            loss = criterion(outputs_magnitude, labels)

            # .item() keeps the running loss a plain float, as in training().
            run_loss += loss.item()
            _, predicted = torch.max(outputs_magnitude, 1)
            total += labels.size(0)
            cnt += 1
            correct += (predicted == labels).sum().item()

    return run_loss / cnt, correct / total

# Main

In [None]:
def main(num_epochs, batch_size, learning_rate, classes, train_loader=train_loader, val_loader=val_loader, test_loader=test_loader, noise_type=None, load=False, save=False):
    """Train, validate and noise-test the model selected by `args.model`.

    Args:
        num_epochs: number of epochs to train in this call.
        batch_size: batch size passed to `dataloaders.make_loader`.
        learning_rate: learning rate of the Adam optimiser.
        classes: class names; currently unused, kept for call compatibility.
        train_loader, val_loader, test_loader: objects returned by the
            `dataloaders.*_data()` helpers; wrapped with `make_loader` here.
        noise_type: corruption evaluated after training ('gaussian', 's&p',
            'localvar', 'speckle', 'fog', 'poisson'); None skips the test.
        load: restore weights, epoch counter and best accuracy from
            `<loader_path>/<args.model>.pth` before training.
        save: write a checkpoint to that file whenever validation accuracy
            improves.

    Returns:
        The trained model.
    """
    run = wandb.init(project='internship-anna', name=args.model+"_test", config={
            "epochs": num_epochs,
            "batch_size": batch_size,
            "lr": learning_rate,
            "model": './results/'}, save_code=True)

    # Always close the wandb run, even when training is interrupted by an error.
    try:
        train_loader = dataloaders.make_loader(train_loader, batch_size)
        val_loader = dataloaders.make_loader(val_loader, batch_size)
        RGBtrain_loader = dataloaders.RGBtest_data()
        RGBtrain_loader = dataloaders.make_loader(RGBtrain_loader, batch_size)

        # Membership tests: `x == 'a' or 'b'` is always truthy, which used to
        # send every non-AlexNet_complex model to the VGG11 branch.
        if args.model in ('VGG11_complex', 'VGG11_real'):
            model = VGG11(num_classes=args.num_classes)
        elif args.model in ('VGG13_complex', 'VGG13_real'):
            model = VGG13(num_classes=args.num_classes)
        elif args.model in ('VGG16_complex', 'VGG16_real'):
            model = VGG16(num_classes=args.num_classes)
        elif args.model in ('VGG19_complex', 'VGG19_real'):
            model = VGG19(num_classes=args.num_classes)
        else:   # every AlexNet variant imports its class under the same name
            model = AlexNet(num_classes=args.num_classes)
        model = model.to(device)

        # The checkpoint path is needed for saving as well as for loading.
        if load or save:
            model_path = loader_path+f"/{args.model}.pth"

        if load :
            # map_location lets a GPU checkpoint be restored on a CPU runtime.
            dict_loaded = torch.load(model_path, map_location=device)
            model.load_state_dict(dict_loaded['model'])
            epochs = dict_loaded['epoch']
            best_test_acc = dict_loaded['acc']
        else :
            epochs = 0
            best_test_acc = 0.0

        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
        criterion = nn.CrossEntropyLoss().to(device)
        run.watch(model)

        for epoch in range(epochs, num_epochs+epochs):

            train_loss, train_acc = training(model, num_epochs, epoch, train_loader, optimizer, criterion)
            val_loss, val_acc = validation(model, val_loader, criterion)

            if val_acc > best_test_acc:
                print("Saving ... ")
                state = {
                        'model': model.state_dict(),
                        'acc': val_acc,
                        'epoch': epoch,
                        }

                if save :
                    os.makedirs(os.path.dirname(model_path), exist_ok=True)
                    torch.save(state, model_path)
                best_test_acc = val_acc

            print(f"Train_loss: {train_loss}")
            print(f"Val: {val_loss}")

            wandb.log({
                "Epoch": epoch + 1,
                "Train Loss": train_loss,
                "Train Accuracy": train_acc,
                "Val Loss": val_loss,
                "Val Accuracy": val_acc,
                "Best Accuracy": best_test_acc
            })

        test_loader = dataloaders.make_loader(test_loader, batch_size)

        # Noise strengths evaluated for each corruption; none when disabled.
        if noise_type == 'gaussian':
            noiseLevel = [0, .15, .25, .35, .45, .6]
        elif noise_type == 's&p':
            noiseLevel = [0, .01, .03, .06, .1, .2]
        elif noise_type == 'localvar':
            noiseLevel = [.01, .03, .06, .1, .2]
        elif noise_type == 'speckle':
            noiseLevel = [0, 1, 5, 10, 15, 30]
        elif noise_type == 'fog':
            noiseLevel = [i/10 for i in range(8)]
        elif noise_type == 'poisson':
            noiseLevel = [20, 50, 100, 300, 600, 900, 1200]
        else:
            noiseLevel = []

        for noise_level in noiseLevel:
            test_loss, test_acc = testing(model, test_loader, criterion, noise_type, RGBtrain_loader, noise_level)
            print(f"{noise_type} noise, level {noise_level}")
            print(f"Test Loss: {test_loss}")
            print(f"Test Accuracy: {test_acc*100} %")
            print()

            # Log the level too, so the test curve can be read against it.
            wandb.log({
                "Noise Level": noise_level,
                "Test Loss": test_loss,
                "Test Accuracy": test_acc
            })
    finally:
        run.finish()

    return model

# The learning rate lives in `args.lr`: argparse names the attribute after the
# first long option (`--lr`), so `args.learning_rate` does not exist.
model = main(args.epochs, args.batch_size, args.lr, classes, train_loader=train_loader, val_loader=val_loader, test_loader=test_loader, noise_type=args.noise_type, load=args.load, save=args.save)



Files already downloaded and verified

Epoch 0
Accuracy of the network: 9.53 %
Saving ... 
Train_loss: 2.451591078642827
Val: 2.315023899078369


VBox(children=(Label(value='0.084 MB of 0.084 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Best Accuracy,▁
Epoch,▁
Train Accuracy,▁
Train Loss,▁
Val Accuracy,▁
Val Loss,▁

0,1
Best Accuracy,0.0953
Epoch,1.0
Train Accuracy,0.0953
Train Loss,2.45159
Val Accuracy,0.0953
Val Loss,2.31502
