# Notebook for testing the PyTorch setup

This notebook is for testing the [PyTorch](http://pytorch.org/) setup. Below is a set of required imports.

Run the cell below; no error messages should appear.

Some warnings may appear, which is fine.

In [None]:
%matplotlib inline

import os

import torch
import torchvision
#from torch import nn
from torch.utils.data import DataLoader
#from torchvision import datasets
import torchvision.transforms as transforms

if not os.path.isfile('pml_utils.py'):
  !wget https://raw.githubusercontent.com/csc-training/intro-to-dl/master/day1/pml_utils.py
from pml_utils import show_failures

#from sklearn.model_selection import train_test_split

from packaging.version import Version as LV

#import numpy as np
#import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

print('Using PyTorch version:', torch.__version__)
assert LV(torch.__version__) >= LV("2.0")

Let's check whether a GPU is available.

In [None]:
if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        print(f'GPU {i}:', torch.cuda.get_device_name(i))
else:
    print('No GPU, using CPU instead.') 
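
Once the check above has run, a common next step is to pick a device and place tensors on it. The following is a minimal sketch, not part of the original setup test; the names `device`, `x`, and `total` are chosen here for illustration.

```python
import torch

# Pick the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tensors are created on the CPU by default; .to(device) moves them.
x = torch.ones(2, 3).to(device)

# Operations run on whichever device the tensor lives on.
total = (x * 2).sum()

print("Running on:", device)
print("Sum:", total.item())
```

The same code runs unchanged on machines with and without a GPU, which is why the device is usually selected once near the top of a notebook.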

## Getting started

This section is adapted from the [PyTorch Quickstart tutorial](https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html).

PyTorch provides two classes in `torch.utils.data` for working with data:
- `Dataset`, which represents the actual data items, such as images or pieces of text, and their labels
- `DataLoader`, which is used for processing the dataset in batches in an efficient manner.
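
To illustrate how the two classes fit together, here is a minimal sketch of a custom map-style `Dataset` wrapped in a `DataLoader`. The `SquaresDataset` class is a hypothetical toy example made up for this illustration; it is not part of PyTorch or of this notebook.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: item i is the pair (i, i**2)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # Number of items in the dataset.
        return self.n

    def __getitem__(self, idx):
        # Return one (input, label) pair as tensors.
        x = torch.tensor([float(idx)])
        y = torch.tensor([float(idx) ** 2])
        return x, y


dataset = SquaresDataset(10)
loader = DataLoader(dataset, batch_size=4)

# The DataLoader stacks individual items into batches.
for xb, yb in loader:
    print(xb.shape, yb.shape)
```

A map-style dataset only needs `__len__` and `__getitem__`; the `DataLoader` takes care of batching, and the last batch is smaller when the dataset size is not a multiple of the batch size.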

PyTorch has domain-specific libraries, such as [TorchText](https://pytorch.org/text/stable/index.html), [TorchVision](https://pytorch.org/vision/stable/index.html) and [TorchAudio](https://pytorch.org/audio/stable/index.html), with utilities for common data types.

Here we will use TorchVision and its `torchvision.datasets` module, which provides easy access to [many common visual datasets](https://pytorch.org/vision/stable/datasets.html). In this example we'll use the `FakeData` class, which just generates random data.

In [None]:
training_data = torchvision.datasets.FakeData(transform=transforms.ToTensor())

In [None]:
batch_size = 64
train_dataloader = DataLoader(training_data, batch_size=batch_size)

In [None]:
for X, y in train_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break
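
The loop above stops after the first batch. As a rough sketch of how batching divides a whole dataset, the example below uses a stand-in `TensorDataset` of zeros with 1000 samples (matching the default size of `FakeData`) and counts the batches. The tiny 8×8 image size and the names `n_samples`, `data`, `loader`, and `loader_full` are assumptions made here to keep the example fast.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

n_samples, batch_size = 1000, 64

# Stand-in data: 1000 all-zero "images" and integer labels.
data = TensorDataset(
    torch.zeros(n_samples, 3, 8, 8),
    torch.zeros(n_samples, dtype=torch.long),
)

loader = DataLoader(data, batch_size=batch_size)

# 1000 / 64 = 15.625, so there are 16 batches and the last one is partial.
sizes = [X.shape[0] for X, _ in loader]
print("Number of batches:", len(loader))
print("Size of last batch:", sizes[-1])

# drop_last=True discards the incomplete final batch.
loader_full = DataLoader(data, batch_size=batch_size, drop_last=True)
print("Batches with drop_last=True:", len(loader_full))
```

Whether to keep the partial batch is a design choice: keeping it uses every sample, while dropping it guarantees that every batch has the same shape.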