From 6a0591b84f84e0e5fb6da7df3d5e3af1487ec5e2 Mon Sep 17 00:00:00 2001
From: Phillip Lippe
Date: Tue, 31 Aug 2021 16:56:59 +0200
Subject: [PATCH] Adding UvA-DL notebook "Introduction to PyTorch" (#73)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ethan Harris
---
 .../introduction-to-pytorch/.meta.yml         |  14 +
 .../Introduction_to_PyTorch.py                | 986 ++++++++++++++++++
 .../comparison_CPU_GPU.png                    | Bin 0 -> 37383 bytes
 .../continuous_xor.svg                        |   1 +
 .../pytorch_computation_graph.svg             |   1 +
 .../small_neural_network.svg                  |   1 +
 6 files changed, 1003 insertions(+)
 create mode 100644 course_UvA-DL/introduction-to-pytorch/.meta.yml
 create mode 100644 course_UvA-DL/introduction-to-pytorch/Introduction_to_PyTorch.py
 create mode 100644 course_UvA-DL/introduction-to-pytorch/comparison_CPU_GPU.png
 create mode 100644 course_UvA-DL/introduction-to-pytorch/continuous_xor.svg
 create mode 100644 course_UvA-DL/introduction-to-pytorch/pytorch_computation_graph.svg
 create mode 100644 course_UvA-DL/introduction-to-pytorch/small_neural_network.svg

diff --git a/course_UvA-DL/introduction-to-pytorch/.meta.yml b/course_UvA-DL/introduction-to-pytorch/.meta.yml
new file mode 100644
index 000000000..1a3726b25
--- /dev/null
+++ b/course_UvA-DL/introduction-to-pytorch/.meta.yml
@@ -0,0 +1,14 @@
+title: Introduction to PyTorch
+author: Phillip Lippe
+created: 2021-08-27
+updated: 2021-08-27
+license: CC BY-SA
+description: |
+  This tutorial will give a short introduction to PyTorch basics, and get you set up for writing your own neural networks.
+  This notebook is part of a lecture series on Deep Learning at the University of Amsterdam.
+  The full list of tutorials can be found at https://uvadlc-notebooks.rtfd.io.
+requirements:
+  - matplotlib
+accelerator:
+  - CPU
+  - GPU
diff --git a/course_UvA-DL/introduction-to-pytorch/Introduction_to_PyTorch.py b/course_UvA-DL/introduction-to-pytorch/Introduction_to_PyTorch.py
new file mode 100644
index 000000000..18693f734
--- /dev/null
+++ b/course_UvA-DL/introduction-to-pytorch/Introduction_to_PyTorch.py
@@ -0,0 +1,986 @@
+# %% [markdown]
+# Welcome to our PyTorch tutorial for the Deep Learning course 2020 at the University of Amsterdam!
+# The following notebook is meant to give a short introduction to PyTorch basics, and get you set up for writing your own neural networks.
+# PyTorch is an open source machine learning framework that allows you to write your own neural networks and optimize them efficiently.
+# However, PyTorch is not the only framework of its kind.
+# Alternatives to PyTorch include [TensorFlow](https://www.tensorflow.org/), [JAX](https://github.com/google/jax#quickstart-colab-in-the-cloud) and [Caffe](http://caffe.berkeleyvision.org/).
+# We choose to teach PyTorch at the University of Amsterdam because it is well established, has a huge developer community (it was originally developed by Facebook), is very flexible, and is especially popular in research.
+# Many current papers publish their code in PyTorch, so it is good to be familiar with it.
+# Meanwhile, TensorFlow (developed by Google) is usually known as a production-grade deep learning library.
+# Still, if you know one machine learning framework in depth, it is very easy to learn another one because many of them use the same concepts and ideas.
+# For instance, TensorFlow's version 2 was heavily inspired by the most popular features of PyTorch, making the frameworks even more similar.
+# If you are already familiar with PyTorch and have created your own neural network projects, feel free to just skim this notebook.
+#
+# We are of course not the first ones to create a PyTorch tutorial.
+# There are many great tutorials online, including the ["60-min blitz"](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html) on the official [PyTorch website](https://pytorch.org/tutorials/).
+# Yet, we chose to create our own tutorial, designed to give you the basics particularly necessary for the practicals, while still conveying how PyTorch works under the hood.
+# Over the next few weeks, we will also keep exploring new PyTorch features in this series of Jupyter notebook tutorials about deep learning.
+#
+# We will use a set of standard libraries that are often used in machine learning projects.
+# If you are running this notebook on Google Colab, all libraries should be pre-installed.
+# If you are running this notebook locally, make sure you have installed our `dl2020` environment ([link](https://github.com/uvadlc/uvadlc_practicals_2020/blob/master/environment.yml)) and have activated it.
+
+# %%
+import time
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.utils.data as data
+
+# %matplotlib inline
+from IPython.display import set_matplotlib_formats
+from matplotlib.colors import to_rgba
+from tqdm.notebook import tqdm  # Progress bar
+
+set_matplotlib_formats("svg", "pdf")
+
+# %% [markdown]
+# ## The Basics of PyTorch
+#
+# We will start with reviewing the very basic concepts of PyTorch.
+# As a prerequisite, we recommend being familiar with the `numpy` package, as most machine learning frameworks are based on very similar concepts.
+# If you are not familiar with numpy yet, don't worry: here is a [tutorial](https://numpy.org/devdocs/user/quickstart.html) to go through.
+#
+# So, let's start with importing PyTorch.
+# The package is called `torch`, based on its original framework [Torch](http://torch.ch/).
+# As a first step, we can check its version:
+
+# %%
+print("Using torch", torch.__version__)
+
+# %% [markdown]
+# At the time of writing this tutorial (mid-August 2021), the current stable version is 1.9.
+# You should therefore see the output `Using torch 1.9.0`, possibly with some extension for the CUDA version on Colab.
+# In case you use the `dl2020` environment, you should see `Using torch 1.6.0` since the environment was provided in October 2020.
+# It is recommended to update the PyTorch version to the newest one.
+# If you see a lower version number than 1.6, make sure you have installed the correct environment, or ask one of your TAs.
+# In case PyTorch 1.10 or newer is released during the course, don't worry.
+# The interface between PyTorch versions doesn't change too much, and hence all code should also be runnable with newer versions.
+#
+# As in every machine learning framework, PyTorch provides functions that are stochastic, like generating random numbers.
+# However, a very good practice is to set up your code to be reproducible with the exact same random numbers.
+# This is why we set a seed below.
+
+# %%
+torch.manual_seed(42)  # Setting the seed
+
+# %% [markdown]
+# ### Tensors
+#
+# Tensors are the PyTorch equivalent to Numpy arrays, with the addition of built-in support for GPU acceleration (more on that later).
+# The name "tensor" is a generalization of concepts you already know.
+# For instance, a vector is a 1-D tensor, and a matrix a 2-D tensor.
+# When working with neural networks, we will use tensors of various shapes and numbers of dimensions.
+#
+# Most common functions you know from numpy can be used on tensors as well.
+# Actually, since numpy arrays are so similar to tensors, we can convert most tensors to numpy arrays (and back), but we don't need it too often.
+#
+# #### Initialization
+#
+# Let's first start by looking at different ways of creating a tensor.
+# There are many possible options, the simplest being to call
+# `torch.Tensor` with the desired shape as input argument:
+
+# %%
+x = torch.Tensor(2, 3, 4)
+print(x)
+
+# %% [markdown]
+# The function `torch.Tensor` allocates memory for the desired tensor, but reuses whatever values were already in that memory, i.e. the tensor is uninitialized.
+# To directly assign values to the tensor during initialization, there are many alternatives including:
+#
+# * `torch.zeros`: Creates a tensor filled with zeros
+# * `torch.ones`: Creates a tensor filled with ones
+# * `torch.rand`: Creates a tensor with random values uniformly sampled between 0 and 1
+# * `torch.randn`: Creates a tensor with random values sampled from a normal distribution with mean 0 and variance 1
+# * `torch.arange`: Creates a tensor containing the values $N,N+1,N+2,...,M-1$ (the end value $M$ is exclusive)
+# * `torch.Tensor` (input list): Creates a tensor from the list elements you provide
+
+# %%
+# Create a tensor from a (nested) list
+x = torch.Tensor([[1, 2], [3, 4]])
+print(x)
+
+# %%
+# Create a tensor with random values between 0 and 1 with the shape [2, 3, 4]
+x = torch.rand(2, 3, 4)
+print(x)
+
+# %% [markdown]
+# You can obtain the shape of a tensor in the same way as in numpy (`x.shape`), or using the `.size()` method:
+
+# %%
+shape = x.shape
+print("Shape:", x.shape)
+
+size = x.size()
+print("Size:", size)
+
+dim1, dim2, dim3 = x.size()
+print("Size:", dim1, dim2, dim3)
+
+# %% [markdown]
+# #### Tensor to Numpy, and Numpy to Tensor
+#
+# Tensors can be converted to numpy arrays, and numpy arrays back to tensors.
+# To transform a numpy array into a tensor, we can use the function `torch.from_numpy`:
+
+# %%
+np_arr = np.array([[1, 2], [3, 4]])
+tensor = torch.from_numpy(np_arr)
+
+print("Numpy array:", np_arr)
+print("PyTorch tensor:", tensor)
+
+# %% [markdown]
+# To transform a PyTorch tensor back to a numpy array, we can use the function `.numpy()` on tensors:
+
+# %%
+tensor = torch.arange(4)
+np_arr = tensor.numpy()
+
+print("PyTorch tensor:", tensor)
+print("Numpy array:", np_arr)
+
+# %% [markdown]
+# The conversion of tensors to numpy requires the tensor to be on the CPU, and not the GPU (more on GPU support in a later section).
+# In case you have a tensor on the GPU, you need to call `.cpu()` on the tensor beforehand.
+# Hence, you get a line like `np_arr = tensor.cpu().numpy()`.
+
+# %% [markdown]
+# #### Operations
+#
+# Most operations that exist in numpy also exist in PyTorch.
+# A full list of operations can be found in the [PyTorch documentation](https://pytorch.org/docs/stable/tensors.html#), but we will review the most important ones here.
+#
+# The simplest operation is to add two tensors:
+
+# %%
+x1 = torch.rand(2, 3)
+x2 = torch.rand(2, 3)
+y = x1 + x2
+
+print("X1", x1)
+print("X2", x2)
+print("Y", y)
+
+# %% [markdown]
+# Calling `x1 + x2` creates a new tensor containing the sum of the two inputs.
+# However, we can also use in-place operations that are applied directly on the memory of a tensor.
+# We thereby change the values of `x2`, without any way of re-accessing the values `x2` held before the operation.
+# An example is shown below:
+
+# %%
+x1 = torch.rand(2, 3)
+x2 = torch.rand(2, 3)
+print("X1 (before)", x1)
+print("X2 (before)", x2)
+
+x2.add_(x1)
+print("X1 (after)", x1)
+print("X2 (after)", x2)
+
+# %% [markdown]
+# In-place operations are usually marked with an underscore postfix (e.g. `add_` instead of `add`).
+#
+# Another common operation aims at changing the shape of a tensor.
+# A tensor of size (2,3) can be re-organized to any other shape with the same number of elements (e.g. a tensor of size (6), or (3,2), ...).
+# In PyTorch, this operation is called `view`:
+
+# %%
+x = torch.arange(6)
+print("X", x)
+
+# %%
+x = x.view(2, 3)
+print("X", x)
+
+# %%
+x = x.permute(1, 0)  # Swapping dimension 0 and 1
+print("X", x)
+
+# %% [markdown]
+# Other commonly used operations include matrix multiplications, which are essential for neural networks.
+# Quite often, we have an input vector $\mathbf{x}$, which is transformed using a learned weight matrix $\mathbf{W}$.
+# There are multiple ways and functions to perform matrix multiplication, some of which we list below:
+#
+# * `torch.matmul`: Performs the matrix product over two tensors, where the specific behavior depends on the dimensions.
+# If both inputs are matrices (2-dimensional tensors), it performs the standard matrix product.
+# For higher dimensional inputs, the function supports broadcasting (for details see the [documentation](https://pytorch.org/docs/stable/generated/torch.matmul.html?highlight=matmul#torch.matmul)).
+# Can also be written as `a @ b`, similar to numpy.
+# * `torch.mm`: Performs the matrix product over two matrices, but doesn't support broadcasting (see [documentation](https://pytorch.org/docs/stable/generated/torch.mm.html?highlight=torch%20mm#torch.mm))
+# * `torch.bmm`: Performs the matrix product with an additional batch dimension.
+# If the first tensor $T$ is of shape ($b\times n\times m$), and the second tensor $R$ ($b\times m\times p$), the output $O$ is of shape ($b\times n\times p$), and has been calculated by performing $b$ matrix multiplications of the submatrices of $T$ and $R$: $O_i = T_i @ R_i$
+# * `torch.einsum`: Performs matrix multiplications and more (i.e. sums of products) using the Einstein summation convention.
+# An explanation of the Einstein sum can be found in assignment 1.
+#
+# Usually, we use `torch.matmul` or `torch.bmm`. We can try a matrix multiplication with `torch.matmul` below.
+
+# %%
+x = torch.arange(6)
+x = x.view(2, 3)
+print("X", x)
+
+# %%
+W = torch.arange(9).view(3, 3)  # We can also stack multiple operations in a single line
+print("W", W)
+
+# %%
+h = torch.matmul(x, W)  # Verify the result by calculating it by hand too!
+print("h", h)
+
+# %% [markdown]
+# #### Indexing
+#
+# We often have the situation where we need to select a part of a tensor.
+# Indexing works just like in numpy, so let's try it:
+
+# %%
+x = torch.arange(12).view(3, 4)
+print("X", x)
+
+# %%
+print(x[:, 1])  # Second column
+
+# %%
+print(x[0])  # First row
+
+# %%
+print(x[:2, -1])  # First two rows, last column
+
+# %%
+print(x[1:3, :])  # Middle two rows
+
+# %% [markdown]
+# ### Dynamic Computation Graph and Backpropagation
+#
+# One of the main reasons for using PyTorch in Deep Learning projects is that we can automatically get **gradients/derivatives** of functions that we define.
+# We will mainly use PyTorch for implementing neural networks, and they are just fancy functions.
+# If we use weight matrices in our function that we want to learn, then those are called the **parameters** or simply the **weights**.
+#
+# If our neural network outputs a single scalar value, we talk about taking the **derivative**, but you will see that quite often we have **multiple** output variables ("values"); in that case we talk about **gradients**.
+# It's a more general term.
+#
+# Given an input $\mathbf{x}$, we define our function by **manipulating** that input, usually by matrix-multiplications with weight matrices and additions of so-called bias vectors.
+# As we manipulate our input, we are automatically creating a **computational graph**.
+# This graph shows how to arrive at our output from our input.
+# PyTorch is a **define-by-run** framework; this means that we can just do our manipulations, and PyTorch will keep track of the graph for us.
+# Thus, we create a dynamic computation graph along the way.
+#
+# So, to recap: the only thing we have to do is to compute the **output**, and then we can ask PyTorch to automatically get the **gradients**.
+#
+# > **Note: Why do we want gradients?**
+# Consider that we have defined a function, a neural net, that is supposed to compute a certain output $y$ for an input vector $\mathbf{x}$.
+# We then define an **error measure** that tells us how wrong our network is; how bad it is at predicting output $y$ from input $\mathbf{x}$.
+# Based on this error measure, we can use the gradients to **update** the weights $\mathbf{W}$ that were responsible for the output, so that the next time we present input $\mathbf{x}$ to our network, the output will be closer to what we want.
+#
+# The first thing we have to do is to specify which tensors require gradients.
+# By default, when we create a tensor, it does not require gradients.
+
+# %%
+x = torch.ones((3,))
+print(x.requires_grad)
+
+# %% [markdown]
+# We can change this for an existing tensor using the function `requires_grad_()` (the underscore indicating that this is an in-place operation).
+# Alternatively, when creating a tensor, you can pass the argument
+# `requires_grad=True` to most initializers we have seen above.
+
+# %%
+x.requires_grad_(True)
+print(x.requires_grad)
+
+# %% [markdown]
+# In order to get familiar with the concept of a computation graph, we will create one for the following function:
+#
+# $$y = \frac{1}{|x|}\sum_i \left[(x_i + 2)^2 + 3\right]$$
+#
+# You could imagine that $x$ are our parameters, and we want to optimize (either maximize or minimize) the output $y$.
+# For this, we want to obtain the gradients $\partial y / \partial \mathbf{x}$.
+# For our example, we'll use $\mathbf{x}=[0,1,2]$ as our input.
+
+# %%
+x = torch.arange(3, dtype=torch.float32, requires_grad=True)  # Only float tensors can have gradients
+print("X", x)
+
+# %% [markdown]
+# Now let's build the computation graph step by step.
+# You can combine multiple operations in a single line, but we will
+# separate them here to get a better understanding of how each operation
+# is added to the computation graph.
+
+# %%
+a = x + 2
+b = a ** 2
+c = b + 3
+y = c.mean()
+print("Y", y)
+
+# %% [markdown]
+# Using the statements above, we have created a computation graph that looks similar to the figure below:
+#
+# <center><img src="pytorch_computation_graph.svg" /></center>
+#
+# We calculate $a$ based on the inputs $x$ and the constant $2$, $b$ is $a$ squared, and so on.
+# The visualization is an abstraction of the dependencies between inputs and outputs of the operations we have applied.
+# Each node of the computation graph has automatically defined a function for calculating the gradients with respect to its inputs, `grad_fn`.
+# You can see this when we printed the output tensor $y$.
+# This is why the computation graph is usually visualized in the reverse direction (arrows point from the result to the inputs).
+# We can perform backpropagation on the computation graph by calling the
+# function `backward()` on the last output, which effectively calculates
+# the gradients for each tensor that has the property
+# `requires_grad=True`:
+
+# %%
+y.backward()
+
+# %% [markdown]
+# `x.grad` will now contain the gradient $\partial y/ \partial \mathbf{x}$, and this gradient indicates how a change in $\mathbf{x}$ will affect output $y$ given the current input $\mathbf{x}=[0,1,2]$:
+
+# %%
+print(x.grad)
+
+# %% [markdown]
+# We can also verify these gradients by hand.
+# We will calculate the gradients using the chain rule, in the same way as PyTorch did it:
+#
+# $$\frac{\partial y}{\partial x_i} = \frac{\partial y}{\partial c_i}\frac{\partial c_i}{\partial b_i}\frac{\partial b_i}{\partial a_i}\frac{\partial a_i}{\partial x_i}$$
+#
+# Note that we have simplified this equation to index notation, using the fact that all operations besides the mean do not combine the elements in the tensor.
+# The partial derivatives are:
+#
+# $$
+# \frac{\partial a_i}{\partial x_i} = 1,\hspace{1cm}
+# \frac{\partial b_i}{\partial a_i} = 2\cdot a_i,\hspace{1cm}
+# \frac{\partial c_i}{\partial b_i} = 1,\hspace{1cm}
+# \frac{\partial y}{\partial c_i} = \frac{1}{3}
+# $$
+#
+# Hence, with the input being $\mathbf{x}=[0,1,2]$, our gradients are $\partial y/\partial \mathbf{x}=[4/3,2,8/3]$.
+# The previous code cell should have printed the same result.
+
+# %% [markdown]
+# ### GPU support
+#
+# A crucial feature of PyTorch is the support of GPUs, short for Graphics Processing Units.
+# A GPU can perform many thousands of small operations in parallel, making it very well suited for performing large matrix operations in neural networks.
+# When comparing GPUs to CPUs, we can list the following main differences (credit: [Kevin Krewell, 2009](https://blogs.nvidia.com/blog/2009/12/16/whats-the-difference-between-a-cpu-and-a-gpu/))
+#
+# <center><img src="comparison_CPU_GPU.png" /></center>
+#
+# CPUs and GPUs each have different advantages and disadvantages, which is why many computers contain both components and use them for different tasks.
+# In case you are not familiar with GPUs, you can read up on more details in this [NVIDIA blog post](https://blogs.nvidia.com/blog/2009/12/16/whats-the-difference-between-a-cpu-and-a-gpu/) or [here](https://www.intel.com/content/www/us/en/products/docs/processors/what-is-a-gpu.html).
+#
+# GPUs can accelerate the training of your network up to a factor of $100$, which is essential for large neural networks.
+# PyTorch implements a lot of functionality for supporting GPUs (mostly those of NVIDIA due to the libraries [CUDA](https://developer.nvidia.com/cuda-zone) and [cuDNN](https://developer.nvidia.com/cudnn)).
+# First, let's check whether you have a GPU available:
+
+# %%
+gpu_avail = torch.cuda.is_available()
+print(f"Is the GPU available? {gpu_avail}")
+
+# %% [markdown]
+# If you have a GPU on your computer but the command above returns False, make sure you have the correct CUDA version installed.
+# The `dl2020` environment comes with the CUDA toolkit 10.1, which is selected for the Lisa supercomputer.
+# Please change it if necessary (CUDA 10.2 is currently common).
+# On Google Colab, make sure that you have selected a GPU in your runtime setup (in the menu, check under `Runtime -> Change runtime type`).
+#
+# By default, all tensors you create are stored on the CPU.
+# We can push a tensor to the GPU by using the function `.to(...)`, or `.cuda()`.
+# However, it is often a good practice to define a `device` object in your code which points to the GPU if you have one, and otherwise to the CPU.
+# Then, you can write your code with respect to this device object, and it allows you to run the same code on both a CPU-only system and one with a GPU.
+# Let's try it below.
+# We can specify the device as follows:
+
+# %%
+device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+print("Device", device)
+
+# %% [markdown]
+# Now let's create a tensor and push it to the device:
+
+# %%
+x = torch.zeros(2, 3)
+x = x.to(device)
+print("X", x)
+
+# %% [markdown]
+# In case you have a GPU, you should now see the attribute `device='cuda:0'` being printed next to your tensor.
+# The zero next to cuda indicates that this is the zero-th GPU device on your computer.
+# PyTorch also supports multi-GPU systems, but you will only need this once you have very big networks to train (if interested, see the [PyTorch documentation](https://pytorch.org/docs/stable/distributed.html#distributed-basics)).
+# We can also compare the runtime of a large matrix multiplication on the CPU with the same operation on the GPU:
+
+# %%
+x = torch.randn(5000, 5000)
+
+# CPU version
+start_time = time.time()
+_ = torch.matmul(x, x)
+end_time = time.time()
+print(f"CPU time: {(end_time - start_time):6.5f}s")
+
+# GPU version
+x = x.to(device)
+# The first operation on a CUDA device can be slow as it has to establish the CPU-GPU communication first.
+# Hence, we run an arbitrary command first without timing it for a fair comparison.
+if torch.cuda.is_available():
+    _ = torch.matmul(x * 0.0, x)
+start_time = time.time()
+_ = torch.matmul(x, x)
+end_time = time.time()
+print(f"GPU time: {(end_time - start_time):6.5f}s")
+
+# %% [markdown]
+# Depending on the size of the operation and the CPU/GPU in your system, the speedup of this operation can be >500x.
+# As `matmul` operations are very common in neural networks, we can already see the great benefit of training a NN on a GPU.
+# The time estimate can be relatively noisy here because we ran the operation only once.
+# Feel free to extend this to multiple runs, but it also takes longer to run.
+#
+# When generating random numbers, the seed between CPU and GPU is not synchronized.
+# Hence, we need to set the seed on the GPU separately to ensure a reproducible code.
+# Note that due to different GPU architectures, running the same code on different GPUs does not guarantee the same random numbers.
+# Still, we don't want our code to give us a different output every time we run it on the exact same hardware.
+# Hence, we also set the seed on the GPU:
+
+# %%
+# GPU operations have a separate seed we also want to set
+if torch.cuda.is_available():
+    torch.cuda.manual_seed(42)
+    torch.cuda.manual_seed_all(42)
+
+# Additionally, some operations on a GPU are implemented stochastically for efficiency
+# We want to ensure that all operations are deterministic on GPU (if used) for reproducibility
+torch.backends.cudnn.deterministic = True
+torch.backends.cudnn.benchmark = False
+
+# %% [markdown]
+# ## Learning by example: Continuous XOR
+#
+# If we want to build a neural network in PyTorch, we could specify all our parameters (weight matrices, bias vectors) using `Tensors` (with `requires_grad=True`), ask PyTorch to calculate the gradients and then adjust the parameters.
+# But things can quickly get cumbersome if we have a lot of parameters.
+# In PyTorch, there is a package called `torch.nn` that makes building neural networks more convenient.
+#
+# We will introduce the libraries and all additional parts you might need to train a neural network in PyTorch, using a simple example classifier on a simple yet well-known example: XOR.
+# Given two binary inputs $x_1$ and $x_2$, the label to predict is $1$ if either $x_1$ or $x_2$ is $1$ while the other is $0$, and $0$ in all other cases.
+# The example became famous because a single neuron, i.e. a linear classifier, cannot learn this simple function.
+# Hence, we will learn how to build a small neural network that can learn this function.
+# To make it a little bit more interesting, we move the XOR into continuous space and introduce some Gaussian noise on the binary inputs.
+# Our desired separation of an XOR dataset could look as follows:
+#
+# <center><img src="continuous_xor.svg" /></center>
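+
+# %% [markdown]
+# To make the label rule concrete, here is a small sketch (illustrative only, not part of the training pipeline) that computes the XOR labels for the four clean binary input combinations, using the same rule as the dataset class defined later:
+
+# %%
+xor_inputs = torch.Tensor([[0, 0], [0, 1], [1, 0], [1, 1]])
+xor_labels = (xor_inputs.sum(dim=1) == 1).long()  # Label is 1 iff exactly one of the two inputs is 1
+print("Inputs:\n", xor_inputs)
+print("Labels:", xor_labels)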
+
+# %% [markdown]
+# ### The model
+#
+# The package `torch.nn` defines a series of useful classes like linear network layers, activation functions, loss functions etc.
+# A full list can be found [here](https://pytorch.org/docs/stable/nn.html).
+# In case you need a certain network layer, check the documentation of the package first before writing the layer yourself, as the package likely contains the code for it already.
+# We have already imported it at the top of this notebook as `nn`.
+
+# %% [markdown]
+# In addition to `torch.nn`, there is also `torch.nn.functional`.
+# It contains functions that are used in network layers.
+# This is in contrast to `torch.nn` which defines them as `nn.Modules` (more on this below), and `torch.nn` actually uses a lot of functionalities from `torch.nn.functional`.
+# Hence, the functional package is useful in many situations, and it is commonly imported as `import torch.nn.functional as F`.
+
+# %% [markdown]
+# #### nn.Module
+#
+# In PyTorch, a neural network is built up out of modules.
+# Modules can contain other modules, and a neural network is considered to be a module itself as well.
+# The basic template of a module is as follows:
+
+
+# %%
+class MyModule(nn.Module):
+    def __init__(self):
+        super().__init__()
+        # Some init for my module
+
+    def forward(self, x):
+        # Function for performing the calculation of the module.
+        pass
+
+
+# %% [markdown]
+# The forward function is where the computation of the module takes place, and it is executed when you call the module (`module = MyModule(); module(x)`).
+# In the init function, we usually create the parameters of the module, using `nn.Parameter`, or define other modules that are used in the forward function.
+# The backward calculation is done automatically, but could be overwritten as well if wanted.
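+
+# %% [markdown]
+# As a tiny illustration of `nn.Parameter` (a sketch with a made-up module name, not used further below), a module holding a single learnable scalar could look like this:
+
+
+# %%
+class ScaleModule(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.scale = nn.Parameter(torch.tensor(1.0))  # Registered as a learnable parameter
+
+    def forward(self, x):
+        return self.scale * x
+
+
+print(list(ScaleModule().parameters()))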
+
+# %% [markdown]
+# #### Simple classifier
+# We can now make use of the pre-defined modules in the `torch.nn` package, and define our own small neural network.
+# We will use a minimal network with an input layer, one hidden layer with tanh as activation function, and an output layer.
+# In other words, our network should look something like this:
+#
+# <center><img src="small_neural_network.svg" /></center>
+#
+# The input neurons are shown in blue, and represent the coordinates $x_1$ and $x_2$ of a data point.
+# The hidden neurons including a tanh activation are shown in white, and the output neuron in red.
+# In PyTorch, we can define this as follows:
+
+
+# %%
+class SimpleClassifier(nn.Module):
+    def __init__(self, num_inputs, num_hidden, num_outputs):
+        super().__init__()
+        # Initialize the modules we need to build the network
+        self.linear1 = nn.Linear(num_inputs, num_hidden)
+        self.act_fn = nn.Tanh()
+        self.linear2 = nn.Linear(num_hidden, num_outputs)
+
+    def forward(self, x):
+        # Perform the calculation of the model to determine the prediction
+        x = self.linear1(x)
+        x = self.act_fn(x)
+        x = self.linear2(x)
+        return x
+
+
+# %% [markdown]
+# For the examples in this notebook, we will use a tiny neural network with two input neurons and four hidden neurons.
+# As we perform binary classification, we will use a single output neuron.
+# Note that we do not apply a sigmoid on the output yet.
+# This is because other functions, especially the loss, are more efficient and precise to calculate on the original outputs instead of the sigmoid output.
+# We will discuss the detailed reason later.
+
+# %%
+model = SimpleClassifier(num_inputs=2, num_hidden=4, num_outputs=1)
+# Printing a module shows all its submodules
+print(model)
+
+# %% [markdown]
+# Printing the model lists all the submodules it contains.
+# The parameters of a module can be obtained by using its `parameters()` method, or `named_parameters()` to also get a name for each parameter object.
+# For our small neural network, we have the following parameters:
+
+# %%
+for name, param in model.named_parameters():
+    print(f"Parameter {name}, shape {param.shape}")
+
+# %% [markdown]
+# Each linear layer has a weight matrix of the shape `[output, input]`, and a bias of the shape `[output]`.
+# The tanh activation function does not have any parameters.
+# Note that parameters are only registered for `nn.Module` objects that are direct object attributes, i.e. `self.a = ...`.
+# If you define a list of modules, their parameters are not registered for the outer module, which can cause some issues when you try to optimize your module.
+# There are alternatives, like `nn.ModuleList`, `nn.ModuleDict` and `nn.Sequential`, that allow you to have different data structures of modules.
+# We will use them in a few later tutorials and explain them there; a small preview of `nn.Sequential` is sketched below.
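+
+# %% [markdown]
+# As a brief preview (this `sequential_model` is a sketch and is not used further in this notebook), the same architecture as `SimpleClassifier` could be written compactly with `nn.Sequential`, which chains modules and registers their parameters correctly:
+
+# %%
+sequential_model = nn.Sequential(
+    nn.Linear(2, 4),
+    nn.Tanh(),
+    nn.Linear(4, 1),
+)
+print(sequential_model)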
+
+# %% [markdown]
+# ### The data
+#
+# PyTorch also provides a few functionalities to load the training and
+# test data efficiently, summarized in the package `torch.utils.data`, which we imported at the top of this notebook as `data`.
+
+# %% [markdown]
+# The data package defines two classes which are the standard interface for handling data in PyTorch: `data.Dataset`, and `data.DataLoader`.
+# The dataset class provides a uniform interface to access the
+# training/test data, while the data loader makes sure to efficiently load
+# and stack the data points from the dataset into batches during training.
+
+# %% [markdown]
+# #### The dataset class
+#
+# The dataset class summarizes the basic functionality of a dataset in a natural way.
+# To define a dataset in PyTorch, we simply specify two functions: `__getitem__`, and `__len__`.
+# The get-item function has to return the $i$-th data point in the dataset, while the len function returns the size of the dataset.
+# For the XOR dataset, we can define the dataset class as follows:
+
+# %%
+
+
+class XORDataset(data.Dataset):
+    def __init__(self, size, std=0.1):
+        """
+        Inputs:
+            size - Number of data points we want to generate
+            std - Standard deviation of the noise (see generate_continuous_xor function)
+        """
+        super().__init__()
+        self.size = size
+        self.std = std
+        self.generate_continuous_xor()
+
+    def generate_continuous_xor(self):
+        # Each data point in the XOR dataset has two variables, x and y, that can be either 0 or 1
+        # The label is their XOR combination, i.e. 1 if only x or only y is 1 while the other is 0.
+        # If x=y, the label is 0.
+        data = torch.randint(low=0, high=2, size=(self.size, 2), dtype=torch.float32)
+        label = (data.sum(dim=1) == 1).to(torch.long)
+        # To make it slightly more challenging, we add a bit of Gaussian noise to the data points.
+        data += self.std * torch.randn(data.shape)
+
+        self.data = data
+        self.label = label
+
+    def __len__(self):
+        # Number of data points we have. Alternatively self.data.shape[0], or self.label.shape[0]
+        return self.size
+
+    def __getitem__(self, idx):
+        # Return the idx-th data point of the dataset
+        # If we have multiple things to return (data point and label), we can return them as a tuple
+        data_point = self.data[idx]
+        data_label = self.label[idx]
+        return data_point, data_label
+
+
+# %% [markdown]
+# Let's try to create such a dataset and inspect it:
+
+# %%
+dataset = XORDataset(size=200)
+print("Size of dataset:", len(dataset))
+print("Data point 0:", dataset[0])
+
+# %% [markdown]
+# To better relate to the dataset, we visualize the samples below.
+
+
+# %%
+def visualize_samples(data, label):
+    if isinstance(data, torch.Tensor):
+        data = data.cpu().numpy()
+    if isinstance(label, torch.Tensor):
+        label = label.cpu().numpy()
+    data_0 = data[label == 0]
+    data_1 = data[label == 1]
+
+    plt.figure(figsize=(4, 4))
+    plt.scatter(data_0[:, 0], data_0[:, 1], edgecolor="#333", label="Class 0")
+    plt.scatter(data_1[:, 0], data_1[:, 1], edgecolor="#333", label="Class 1")
+    plt.title("Dataset samples")
+    plt.ylabel(r"$x_2$")
+    plt.xlabel(r"$x_1$")
+    plt.legend()
+
+
+# %%
+visualize_samples(dataset.data, dataset.label)
+plt.show()
+
+# %% [markdown]
+# #### The data loader class
+#
+# The class `torch.utils.data.DataLoader` represents a Python iterable over a dataset with support for automatic batching, multi-process data loading and many more features.
+# The data loader communicates with the dataset using the function `__getitem__`, and stacks its outputs as tensors over the first dimension to form a batch.
+# In contrast to the dataset class, we usually don't have to define our own data loader class, but can just create an object of it with the dataset as input.
+# Additionally, we can configure our data loader with the following input arguments (only a selection, see the full list [here](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)):
+#
+# * `batch_size`: Number of samples to stack per batch
+# * `shuffle`: If True, the data is returned in a random order.
+# This is important during training for introducing stochasticity.
+# * `num_workers`: Number of subprocesses to use for data loading.
+# The default, 0, means that the data will be loaded in the main process, which can slow down training for datasets where loading a data point takes a considerable amount of time (e.g. large images).
+# More workers are recommended for those, but can cause issues on Windows computers.
+# For tiny datasets such as ours, 0 workers are usually faster.
+# * `pin_memory`: If True, the data loader will copy Tensors into CUDA pinned memory before returning them.
+# This can save some time for large data points on GPUs.
+# It is usually a good practice to use it for the training set, but not necessarily for validation and test, to save memory on the GPU.
+# * `drop_last`: If True, the last batch is dropped in case it is smaller than the specified batch size.
+# This occurs when the dataset size is not a multiple of the batch size.
+# It is only potentially helpful during training to keep a consistent batch size.
+#
+# Let's create a simple data loader below:
+
+# %%
+data_loader = data.DataLoader(dataset, batch_size=8, shuffle=True)
+
+# %%
+# next(iter(...)) fetches the first batch of the data loader
+# If shuffle is True, this will return a different batch every time we run this cell
+# For iterating over the whole dataset, we can simply use "for batch in data_loader: ..."
+data_inputs, data_labels = next(iter(data_loader))
+
+# The shape of the outputs is [batch_size, d_1,...,d_N] where d_1,...,d_N are the
+# dimensions of the data point returned from the dataset class
+print("Data inputs", data_inputs.shape, "\n", data_inputs)
+print("Data labels", data_labels.shape, "\n", data_labels)
+
+# %% [markdown]
+# ### Optimization
+#
+# After defining the model and the dataset, it is time to prepare the optimization of the model.
+# During training, we will perform the following steps:
+#
+# 1. Get a batch from the data loader
+# 2. Obtain the predictions from the model for the batch
+# 3. Calculate the loss based on the difference between predictions and labels
+# 4. Backpropagation: calculate the gradients for every parameter with respect to the loss
+# 5. Update the parameters of the model in the direction of the gradients
+#
+# We have seen how we can do steps 1, 2 and 4 in PyTorch. Now, we will look at steps 3 and 5.
+
+# %% [markdown]
+# #### Loss modules
+#
+# We can calculate the loss for a batch by simply performing a few tensor operations, as those are automatically added to the computation graph.
+# For instance, for binary classification, we can use Binary Cross Entropy (BCE) which is defined as follows:
+#
+# $$\mathcal{L}_{BCE} = -\sum_i \left[ y_i \log x_i + (1 - y_i) \log (1 - x_i) \right]$$
+#
+# where $y$ are our labels, and $x$ our predictions, both in the range of $[0,1]$.
+# However, PyTorch already provides a list of predefined loss functions which we can use (see [here](https://pytorch.org/docs/stable/nn.html#loss-functions) for a full list).
+# For instance, for BCE, PyTorch has two modules: `nn.BCELoss()`, `nn.BCEWithLogitsLoss()`.
+# While `nn.BCELoss` expects the inputs $x$ to be in the range $[0,1]$, i.e. the output of a sigmoid, `nn.BCEWithLogitsLoss` combines a sigmoid layer and the BCE loss in a single class.
+# This version is numerically more stable than using a plain sigmoid followed by a BCE loss because of the logarithms applied in the loss function.
+# Hence, it is advised to use loss functions applied on "logits" where possible (remember to not apply a sigmoid on the output of the model in this case!).
+# For our model defined above, we therefore use the module `nn.BCEWithLogitsLoss`.
+
+# %%
+loss_module = nn.BCEWithLogitsLoss()
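+
+# %% [markdown]
+# As a quick sanity check (a small sketch with randomly generated values, not needed for training), the two modules agree when we apply the sigmoid ourselves; `nn.BCEWithLogitsLoss` just computes the same quantity in a numerically safer way:
+
+# %%
+logits = torch.randn(4)  # Raw model outputs ("logits")
+targets = torch.randint(0, 2, (4,)).float()  # Random binary labels
+print("BCEWithLogitsLoss on logits:", nn.BCEWithLogitsLoss()(logits, targets))
+print("BCELoss on sigmoid outputs: ", nn.BCELoss()(torch.sigmoid(logits), targets))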
+
+# %% [markdown]
+# #### Stochastic Gradient Descent
+#
+# For updating the parameters, PyTorch provides the package `torch.optim`, which has most popular optimizers implemented.
+# We will discuss the specific optimizers and their differences later in the course, but will for now use the simplest of them: `torch.optim.SGD`.
+# Stochastic Gradient Descent updates parameters by multiplying the gradients with a small constant, called the learning rate, and subtracting those from the parameters (hence minimizing the loss).
+# Therefore, we slowly move towards the direction of minimizing the loss.
+# A good default value of the learning rate for a small network such as ours is 0.1.
+
+# %%
+# Input to the optimizer are the parameters of the model: model.parameters()
+optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
+
+# %% [markdown]
+# The optimizer provides two useful functions: `optimizer.step()`, and `optimizer.zero_grad()`.
+# The step function updates the parameters based on the gradients as explained above.
+# The function `optimizer.zero_grad()` sets the gradients of all parameters to zero.
+# While this function seems less relevant at first, it is a crucial pre-step before performing backpropagation.
+# If we called the `backward` function on the loss while the parameter gradients are non-zero from the previous batch, the new gradients would actually be added to the previous ones instead of overwriting them.
+# This is done because a parameter might occur multiple times in a computation graph, and we need to sum the gradients in this case instead of replacing them.
+# Hence, remember to call `optimizer.zero_grad()` before calculating the gradients of a batch.
+
+# %% [markdown]
+# ### Training
+#
+# Finally, we are ready to train our model.
+# As a first step, we create a slightly larger dataset and specify a data loader with a larger batch size.
+
+# %%
+train_dataset = XORDataset(size=1000)
+train_data_loader = data.DataLoader(train_dataset, batch_size=128, shuffle=True)
+
+# %% [markdown]
+# Now, we can write a small training function.
+# Remember our five steps: load a batch, obtain the predictions, calculate the loss, backpropagate, and update.
+# Additionally, we have to push all data and model parameters to the device of our choice (GPU if available).
+# For the tiny neural network we have, communicating the data to the GPU actually takes much more time than we could save by running the operations on the GPU.
+# For large networks, the communication time is significantly smaller than the actual runtime, making a GPU crucial in these cases.
+# Still, to practice, we will push the data to the GPU here.
+
+# %%
+# Push model to device. This only has to be done once
+model.to(device)
+
+# %% [markdown]
+# In addition, we set our model to training mode.
+# This is done by calling `model.train()`.
+# There exist certain modules that need to perform a different forward
+# step during training than during testing (e.g. BatchNorm and Dropout),
+# and we can switch between them using `model.train()` and `model.eval()`.
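+
+# %% [markdown]
+# Under the hood, these calls simply flip the module's `training` flag (and that of all its submodules), as this small sketch shows:
+
+# %%
+model.train()
+print("Training mode:", model.training)
+model.eval()
+print("Training mode:", model.training)
+model.train()  # Switch back to training mode before we start training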
+
+
+# %%
+def train_model(model, optimizer, data_loader, loss_module, num_epochs=100):
+    # Set model to train mode
+    model.train()
+
+    # Training loop
+    for epoch in tqdm(range(num_epochs)):
+        for data_inputs, data_labels in data_loader:
+
+            # Step 1: Move input data to device (only strictly necessary if we use GPU)
+            data_inputs = data_inputs.to(device)
+            data_labels = data_labels.to(device)
+
+            # Step 2: Run the model on the input data
+            preds = model(data_inputs)
+            preds = preds.squeeze(dim=1)  # Output is [Batch size, 1], but we want [Batch size]
+
+            # Step 3: Calculate the loss
+            loss = loss_module(preds, data_labels.float())
+
+            # Step 4: Perform backpropagation
+            # Before calculating the gradients, we need to ensure that they are all zero.
+            # The gradients would not be overwritten, but actually added to the existing ones.
+            optimizer.zero_grad()
+            # Perform backpropagation
+            loss.backward()
+
+            # Step 5: Update the parameters
+            optimizer.step()
+
+
+# %%
+train_model(model, optimizer, train_data_loader, loss_module)
+
+# %% [markdown]
+# #### Saving a model
+#
+# After finishing the training of a model, we save it to disk so that we can load the same weights at a later time.
+# For this, we extract the so-called `state_dict` from the model, which contains all learnable parameters.
+# For our simple model, the state dict contains the following entries:
+
+# %%
+state_dict = model.state_dict()
+print(state_dict)
+
+# %% [markdown]
+# To save the state dictionary, we can use `torch.save`:
+
+# %%
+# torch.save(object, filename). For the filename, any extension can be used
+torch.save(state_dict, "our_model.tar")
+
+# %% [markdown]
+# To load a model from a state dict, we use the function `torch.load` to
+# load the state dict from the disk, and the module function
+# `load_state_dict` to overwrite our parameters with the new values:
+
+# %%
+# Load state dict from the disk (make sure it is the same name as above)
+state_dict = torch.load("our_model.tar")
+
+# Create a new model and load the state
+new_model = SimpleClassifier(num_inputs=2, num_hidden=4, num_outputs=1)
+new_model.load_state_dict(state_dict)
+
+# Verify that the parameters are the same
+print("Original model\n", model.state_dict())
+print("\nLoaded model\n", new_model.state_dict())
+
+# %% [markdown]
+# A detailed tutorial on saving and loading models in PyTorch can be found
+# [here](https://pytorch.org/tutorials/beginner/saving_loading_models.html).
+
+# %% [markdown]
+# ### Evaluation
+#
+# Once we have trained a model, it is time to evaluate it on a held-out test set.
+# As our dataset consists of randomly generated data points, we need to
+# first create a test set with a corresponding data loader.
+
+# %%
+test_dataset = XORDataset(size=500)
+# drop_last -> Don't drop the last batch although it is smaller than 128
+test_data_loader = data.DataLoader(test_dataset, batch_size=128, shuffle=False, drop_last=False)
+
+# %% [markdown]
+# As metric, we will use accuracy, which is calculated as follows:
+#
+# $$acc = \frac{\#\text{correct predictions}}{\#\text{all predictions}} = \frac{TP+TN}{TP+TN+FP+FN}$$
+#
+# where TP are the true positives, TN the true negatives, FP the false positives, and FN the false negatives.
+#
+# When evaluating the model, we don't need to keep track of the computation graph, as we don't intend to calculate the gradients.
+# This reduces the required memory and speeds up the model.
+# In PyTorch, we can deactivate the computation graph using `with torch.no_grad(): ...`.
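+
+# %% [markdown]
+# As a minimal sketch of what this does (illustrative only), operations executed inside the `torch.no_grad()` context do not get a `grad_fn`, so no computation graph is built for them:
+
+# %%
+x = torch.ones(3, requires_grad=True)
+print((x + 2).grad_fn)  # Outside the context: a backward function is recorded
+with torch.no_grad():
+    print((x + 2).grad_fn)  # Inside the context: None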
+
+# %% [markdown]
+# Remember to additionally set the model to eval mode.
+
+
+# %%
+def eval_model(model, data_loader):
+    model.eval()  # Set model to eval mode
+    true_preds, num_preds = 0.0, 0.0
+
+    with torch.no_grad():  # Deactivate gradients for the following code
+        for data_inputs, data_labels in data_loader:
+
+            # Determine prediction of model on dev set
+            data_inputs, data_labels = data_inputs.to(device), data_labels.to(device)
+            preds = model(data_inputs)
+            preds = preds.squeeze(dim=1)
+            preds = torch.sigmoid(preds)  # Sigmoid to map predictions between 0 and 1
+            pred_labels = (preds >= 0.5).long()  # Binarize predictions to 0 and 1
+
+            # Keep records of predictions for the accuracy metric (true_preds=TP+TN, num_preds=TP+TN+FP+FN)
+            true_preds += (pred_labels == data_labels).sum()
+            num_preds += data_labels.shape[0]
+
+    acc = true_preds / num_preds
+    print(f"Accuracy of the model: {100.0*acc:4.2f}%")
+
+
+# %%
+eval_model(model, test_data_loader)
+
+# %% [markdown]
+# If we trained our model correctly, we should see a score close to 100% accuracy.
+# However, this is only possible because of our simple task, and
+# unfortunately, we usually don't get such high scores on test sets of
+# more complex tasks.
+
+# %% [markdown]
+# #### Visualizing classification boundaries
+#
+# To visualize what our model has learned, we can perform a prediction for every data point in a range of $[-0.5, 1.5]$, and visualize the predicted class as in the sample figure at the beginning of this section.
+# This shows where the model has created decision boundaries, and which points would be classified as $0$, and which as $1$.
+# We therefore get a background image out of blue (class 0) and orange (class 1).
+# In spots where the model is uncertain, we will see a blurry overlap.
+# The specific code is less relevant compared to the output figure, which
+# should hopefully show us a clear separation of classes:
+
+
+# %%
+@torch.no_grad()  # Decorator, same effect as "with torch.no_grad(): ..." over the whole function.
+def visualize_classification(model, data, label):
+    if isinstance(data, torch.Tensor):
+        data = data.cpu().numpy()
+    if isinstance(label, torch.Tensor):
+        label = label.cpu().numpy()
+    data_0 = data[label == 0]
+    data_1 = data[label == 1]
+
+    plt.figure(figsize=(4, 4))
+    plt.scatter(data_0[:, 0], data_0[:, 1], edgecolor="#333", label="Class 0")
+    plt.scatter(data_1[:, 0], data_1[:, 1], edgecolor="#333", label="Class 1")
+    plt.title("Dataset samples")
+    plt.ylabel(r"$x_2$")
+    plt.xlabel(r"$x_1$")
+    plt.legend()
+
+    # Let's make use of a lot of operations we have learned above
+    model.to(device)
+    c0 = torch.Tensor(to_rgba("C0")).to(device)
+    c1 = torch.Tensor(to_rgba("C1")).to(device)
+    x1 = torch.arange(-0.5, 1.5, step=0.01, device=device)
+    x2 = torch.arange(-0.5, 1.5, step=0.01, device=device)
+    xx1, xx2 = torch.meshgrid(x1, x2)  # Meshgrid function as in numpy
+    model_inputs = torch.stack([xx1, xx2], dim=-1)
+    preds = model(model_inputs)
+    preds = torch.sigmoid(preds)
+    # Specifying "None" in a dimension creates a new one
+    output_image = preds * c0[None, None] + (1 - preds) * c1[None, None]
+    output_image = output_image.cpu().numpy()  # Convert to numpy array. This only works for tensors on the CPU, hence push to CPU first
+    plt.imshow(output_image, origin="upper", extent=(-0.5, 1.5, -0.5, 1.5))
+    plt.grid(False)
+
+
+visualize_classification(model, dataset.data, dataset.label)
+plt.show()
+
+# %% [markdown]
+# The decision boundaries might not look exactly as in the figure in the preamble of this section, which can be caused by running it on CPU or a different GPU architecture.
+# Nevertheless, the result on the accuracy metric should be approximately the same.
+
+# %% [markdown]
+# ## Additional features we didn't get to discuss yet
+#
+# Finally, you are all set to start with your own PyTorch project!
+# In summary, we have looked at how we can build neural networks in PyTorch, and train and test them on data.
+# However, there is still much more to PyTorch that we haven't discussed yet.
+# In the coming series of Jupyter notebooks, we will discover more and more functionalities of PyTorch, so that you also get familiar with PyTorch concepts beyond the basics.
+# If you are already interested in learning more about PyTorch, we recommend the official [tutorial website](https://pytorch.org/tutorials/) that contains many tutorials on various topics.
+# Especially logging with TensorBoard ([tutorial
+# here](https://pytorch.org/tutorials/intermediate/tensorboard_tutorial.html))
+# is a good practice that we will explore from Tutorial 5 on.
diff --git a/course_UvA-DL/introduction-to-pytorch/comparison_CPU_GPU.png b/course_UvA-DL/introduction-to-pytorch/comparison_CPU_GPU.png
new file mode 100644
index 0000000000000000000000000000000000000000..b7d94f91eab5b7d09193e73e3d0a86d21e0232ef
GIT binary patch
literal 37383
[binary image data omitted]
zOg(AKc%0?ChI0_o($Y7Bg7zT*1;9jK*nDeIWi90#=Oh zXgN*qdH2h>xJMwzeaX-&Lx|{e2vt*~w}NsgJhN(fw+fu^|EliHqk3H5fB!M=$X2Fg zC^A&aEK^FPL?oq9C>0weLy3|UBBVkRLX>Eb3=Ilps&)w_6%9zGiH!;obzZmcuzu_O zasE8(9BZ%r-Cs(d=XvhuzOL&%-7}$AO*d}kfeXd%NXvLBxZ*{}m{06zo55|n2(C6_ z5+Mp5_)W2?U%YrxqS#h0#0=|Srz!}CfqEf=rM@7+0+@R?s4|ET=}psK4$Dzf7#kAy zlh698v2hln*I|S9yh?EB3rP2;fq_B#BsYMygpc-~q*RWbI58cu(y@!9Mvua6xaILBtUqT?Qbjidhr$@78{fPxh1si=O4<-%98k!V5+?y5ejRk zS*2|0BA^ynfIgdfD{bGuxF#O6m`l^+0ly!5wd)`--JDPzYZFYu)7dO5sKbV7y(Cjm z#+e@SdY3udT%aD9UuC#(`~U$hj^t5|#ZZ4}kdC6TQP@}TQN=Uk;j=Hss1lCenKtYB zhT#2H%X3DxUP^TKH9WO^iJh(ORpiCIV@ucdL*DWSB$I$YvQ!o#!wBF-XXKw}vX#o4 z-gkBUiYpg2?kE>M1SAl^jj+j@Y~}qB=(330D3%L~E)Fgpc#GC+0=3wOjZ*9s2OHo* zRx-g0bO)&5ylN;t1CmiMXcaa8Q0%Gg`O?f~8}IBDky-1$!xr5G9~FE~gcF1y;--k} z0D6LPK@2$~ZRf9Jb;y>xo}7PiyPe?_=Xc$3^aN?zix+tQ8O=V--e%jci>8y{24xZ5 zt*_EkB9lz5p07A{}V0`A`@;;G%J;KhmAKPrkwWj2L}ZLN7F|?(|_NboC@% z?LT^vnkW3HhO61p1!ls78MEDS4IB^5>Tvx`~2PZ>3) zl4YDgyn!*O(qXkm^XFfs>O|;SfsR}4&D&y4Qk9H&R0!;c)LkZ2rUNKd?6G65GzSc~ zdrWoz(3g6iX5EQn`P4YY|czHrSgmIx%7_v5{>Ei zS%D(9D`J71ghbKGV88uyM*WYKMKs_q3lLR-BKVP{n$Y#k{Yb)a-Hts4eg@vteP&c zv9;YTc7C>qo9N!cZbwE+P2Ny?_@`cSczAf^Von^qEvh6)RoDUz>LTdOvhb}`iwi~ma-w&fS;|nG2acVtLp$=mI6)d=^ucuy ztP4#{&Ns-U5lj1qjo5HC1BU6;DOL;;CT3>djP&Z+YHhb@m5+{#-oC_afO+~PurCpn zd;K3uRbKQR`f1fz2XkQLee_{jA0!g`Faa4Ys92;AL{(Mvv?0x{`SgD0)KsEsohC-q zPE1xdTQR7msW#l{*5ZUQ6ScQ)EKAu7b^)+FGuAacEG!o}KMn1{IzcU0RYf%YxK2aE zkiVJMG~xJboIlhVn!=f}wK4fPG_o4lJbW917>7tSnB{bO+Btqh-aX2U_V*Dg(5Jny)b@h1~vS(8EyqZozhg(fi+#rODxiqM_yzl1a^DJMr%sO~nx3PN&}Q~W12n8Y-cQE$=Y>ob=Ha79cON|1Ro@Mad_mgvbM`vmhK8v+ z%L_Eu-568x_N_+BqGtY%<>+gxT0!MzqSt!+ z#E6-6JlL;?2Uq>`kJX2orcgahc&9zdxw+;651YB&^jqrF@F!*-W)=DsWDUqWK) zlIG}{9Y{`Zd)wZAa#G}lESE2kr4Qi1#143qxX4&fBisBUr_hVP-xh@dbUXd4)v4th zOiYf*}@#F8_pGYKW@poZ5oB z=n3TDfl%gjM8zmR-p5zkfCc=q6f#%yX;wzZjvp_~{&-@4YjRU@-rM)@5Aj-3?;>xz zchuT>{a$vqVU)@;V!Dl9BK*p|dIYsh;B-4*0=D)RP? z2(`r9x3%QuYN z*RWw%^t7SH_yLaJUiz-Ga_#NM2`_^wEDs@e3N}0_F7D4BcVhhNgT;;V0|tnG)^O<8 z;N!<_TYr9+>efw980iP~ty3J3@|UAyY{TTtT~+(%D2NvzrnLRSwv_GFPZWOl2QrGs ze)eu){)@M_wtf^8_IRSjvS@!lbbzs6!!KrTfU=pGBDc)K;yeqwe{}#x#GZE+c(6I5 zLFZO>V2WNGUH>{wM`YcbB{G@b{{AjRXj-hMS36Q`Vb|vtqGlQ4_e)Q2n0}mHyiofv zH%=kgD<~hO=>E~i;Ivqx88lJDtCL;ok3T!NHmEVOzdN~H3c@MS$jTT>4E}YM)UYt! zH65<^mzTf$@?}3!rY@(sH7Y}&IaK-Y%cUHr$DMMX{#HFc*h`$s3%2U;eGm4J}V(Eleca*oqopVx$6p|`&fDYa!GJLjCA(_{2VV${@fO0XU-Ey)-em} z)kAx3oWdwzI_Q*PA{HQMPsmP4Hf~r)VT_AtVABc1P4j|qwY6u#;og_s>WiJyT4iPB2s*L&}GzhhqN4@zx zE#168`+FB_d`+%yZs=nEJem(jt?qZ;bP>~qTHrZEm-DDCa8}@iv5IHP2u6tRQLhGH{%yz#0oVD%d z#Ho|E#Co>1HUni1`xgFuCyYRUXX&aJ_7m73{&6XI@}xV<b!Y|Z~q@0Q>+MnW=FKw8j)It}!XwO%D=jR3kGepchC37UdCe(a)*fH_P%li7 z@@#b{!tEgpO&`xvdY=J&vq&q+?oksHR-39ZZrnXq=vZ)>^@ycE%f|sb0$qp%`S|H< zTu6iX@pFTG8;O?}Tn{>fdw z2eKM_qLp+5pQGsWWa9*Gw+4FUX$%JcLQG}l4SMn1{|`XAOJV)M%O?B zy4crGqXq7(gdXNWH4jJ$hjrA{N`Ar&z49*$w7TS%+`Q)rYlgSFGoQoDe2KmqWm)ARS zGquGr2sfIEE!YQ-YKH{pYGUF(S}BlU!O&gCTHEk;wY~j8Amx>;(@op9Wx~%t-!Mi? 
zYaqtdU1!cHynp{*Xy)(~u7rEFOpkk<+SJ(i0C`1^o;}<1!=_A`a+ggS(DFBkH}u0t zrWfJ{@1jThH#;M+MnOD-um*KbDoOLyA@^r2b%^YnnM@$8)V9`dX`(V+UY4|A#8#t# zKUa(gYvHHx8g^iF7K(xN^iJ@TmLFw?Z2vqq$@TIiXJ3^l_Z10=i5V>9&434@xn?&+ zD~PC0#>(T1c~IztEB=X(2*;U7KPavwPs6-mos-d(4Ho7W>_gav`V?jIstTM##Nyc> zl5_MgV`zHAWB_0)0AZVr8)e+x-SIO2$!(cQ**bPHbfSYCLQx?ld%(}cD#uXX3F=`l zPQ(b6zeF#|5*?_$y*KzL+`(Sf7>4%)=#qxw+^_8;ys&SW`yV1aeN$5-#myAkCrUrC zq6eX#tx-kY1ArRE$WD%v$-^I10Nb|KU-q6S4*hJVA(?N8Sn)r9?qG6JLaI!D(WWh1p1+({ zU;O?N-^N~j@71f+VmGz4&Z`meoSeV-L+hg_PiCXmS~cDxp+55Qwe$0WGO-Zv0Wt(@ z`HU<493L^e-RJ{iLJZsLHww0bK9Cqy+& z98PRu`pTE|j(V`gh!hsZT$45V5Vg)=L+rMA*@EMjdg*RYWGWGe!kl^YV(J&7b8jCP z^P_Qbul)6 zn$msm{?kKPJTWtdFx#?0bhpQ-n z${=)xhaXjmy^HJJg8;kGWxrn2p$UwUpKalf$BZd*R@F0FCzjuKg@FOj1u{f} zvot*&VNF92Uvxx#MOaL8k*2Bj;!y(?qLQto9$bJPb1OWt~ zqrp(&C$#IUOWAi$OxZkyD4Bh9YtL`Sih4_k^F*r!E?X47`DP&VB77sZe(e65e)mds zzwwn{msd`Dx(6`zji7q##~vFi6Lk_cB8<)vUb~03%AuyR8FS}$5yJ%ZjU66K@rD5_ z09pC6vxGC*;xXNet2A|1-iY5`k+&(`!!3DTct2!2p1(3HZ;tuqG!jA3+|raxd-L|K zX+>hClUrRBaSZ??g8oroSkfspk^lCUs*{!l^Ar! zu9N_-?FA`$1BMG(TCi z1I02m{XU`MEEiA$uRDsJur-~Xr9XVy(eTz*G3+d#f=)zG2rjbbWpCZ~%$c^`!)4Cs zU^`^OB*yMB9ny4xA^1MMmC#kJQEigpl!v_y4Gq;-mi%(_3t=HVDBV5BIcFBS<6x_4 zmT0rOk{c(S*rzk-F?oEe$Bc-oa&zw#{j9P7WMoqNk~|ID&^I6ngo7 z=4ClBrDMsRFah`KB$yhK6Nad(OH*SRoO@Zhf0)P`v)1^t(p$ z8?@qEte_G_nel6AdwG5~+)9ZTDJu}2<%BFAIjw`zo$WO5<4?xc1fvIT>ieu-8-xyYp!c3WmHARP_gn072DFhQyYbctGLM;g zob5xYyc9`WNfGVrJ?;_PYrC1UxdgQ`OsM{PRiM4^>FKF^1kX<^0A=@pDV!R0K%~*c zZ@{s9v|cpw1E=qh2UR1reN;&`jB(IE1iyOZ}Oe&zYZ9!t}gsULTrOuN`M`>J*$%Icu&NAhwMFl`m`)#G-G9S!U9wX z@sCDnObw;z*unS;-^}IM*l{4^&CShb$va%0-kX@b*uklLVdZjZRcjsG=$nbxrlS#b zX_!MvzE}m|mg$>K%Ib&ndJB2UKd;kdKQE7?AA^*?=qR>5;e@G)qy5z?D+A5YQYlW z&wytiuXwlO$TND~gAB96!ovo*=h7t=!SI|)oRsuSxC=0~zCm})V3ENk3+X2ce*4@0 zu((ST*PqP(A&^BD9xRL@U5v1z@zM%i36h+5g$?eIUX;&m|TPi5K{*d zqQjw~0Vo6Qu$~=lWEecphAkLhugtzNu&-kO1{IStBDMbjF5{KFUMIBOCrRif7uE*! z_Zav9%qkt1gkBYImq3fS!TZ1y1t!SVFz*SbIyCE`7sxA%viKSt4$@T=JZRj-X1e zu9&4K2u`3SkxGll?*XaGL-pX!`RpL*M_w$3o<0^-; zYYNJkcYytR_$dR~jzU6#3PoHFbD)s5oQE{L>Cuu{`Gbg@PJ*7~rrXf$0lXjr7NXOo zx3=Xb_@x0oXz4gvbD6-L?~))r)H#xF(^h*Uv99SM13%LE2Z_(YkGu6}$OM~EbY*3_yxe)-aB#~{!`6EqRfgrZQ z$%fV&Kin;15H2om;l#)L=*33ngiJP$`pLeg%huINa-aP&s15+D6C1r?R1B+VR&-IO zc*C|XjZK82*h}{-g%+z%0p{){?BYj?>o*fe-p(tRL7HL4AZi1D*=~eZo51HIu(Ex{ z10=$8_9C=b3*zK&BL)>&i4+UU(1q&XUQ*7K8YYiYVmOj5y{bgWp3qOP3EVd8K4S`5 zaoOL`E;cbyqC%U?Pu_5Pp_$nQRWvg_fh^PoC8c3(zBwlB@n<4X#zF?Cf~vufs)P}IuvGD&2cYh{#8*Y$ zb}D1X7^4$^CEr6()%lycAtYv? z8EijGwqrRy*suz-AmD-jgPktnxb@r+JF}42n;;-^5tP`Feh`)-ay)n&6HQ;zT69Dv zR?g}+tit3Y%d72;TF0+2@Hgp{uaP9{UH4Sw-+ni}>bukjX<@j-nWWdac6;~2cb2iW zQ1~wqJdXUTGEm9gNO#)9QUZ8@d2G? 
z%3`-2iL&}&cEt0=ieX06iYt{6!}KgpbW|5PMnS@ocQq-gFMbI9DwiWZ99E?pZ`^5@ z+y36~l=OC=EAl@soZqa3{O?Y)bagv|2bg>p&O5ps25w`s;3G#^xB@I+Jb)&n&wL}J zy=>PeW$dZL1J6I5-ZH{e1`1Epmo$VV5>Xs}O+JBVFRB2~0%j@f#Nq{q6sdNa+g ztXSmmBRlXrH*Q!x7pMO6`?-UolYYSh?G=mR`Wdo+wQeYkVrbmW-@*az>d}v~vU-Rc zOQRK~kfc^T(3q(Vifhj^8ZP?fj(|)N*TJK@0#t`Bm@-q=n>2`fpxtzui`A@&k6cO3 z282#1s+w!A@Zi+ectAa2yi^-=h$3}glBwsO~=CAIGwVVW8Wf>C3A$ zYwvULr;1P>nO!m39sYP+^y7`q1Z4Gv^*h$HEc*7>Vu$;r?+vdPiTJ|x7F(6;>j&Gl z_SCY-{JiVnd|4@o$feODnX2W#V^d4846f#EYD-T~Pt0NlPGG;kNU^cWEhO1qIeCD{skCi!+q2)bC zw&_RRVRl2S&aVlUAIJMN%%Yn5hq!KkHY!Ca=v*NTv`^Sve!mmasClW4Q|KbHh&8u0 z>;9w&Mi&&K98g+3N)Qcju=PDAdDyQYL)P&=XIU-Kg_s}6dczK$$5vO<*iv0EI z3{-k9ly^(dVB@?+ z#fsslTKCDIG-Y$2o9ggO1C*IoueqgVUSSLWQgqNWw%hPoCnbB-*$ZmcgbK#0T1RG0 zUD_)}>yIS~q2??4^6L140s#W6W_O%9gRr96Zu0?)kAP`VP)7}E3oXu>6~E3PP^*&u zm=d{UYU?R*?MB(LGB%33Plw*=zRv;oGpKz3Q9_z@6QW9ld>jx5`DG7_iT;)td-?Kn ze-$#fs?M#V2;*bM)T^Qr^Yi(I^XI{L`t|JDGnbs21lyb2J}QDGA#FSKAIdyq?#1X2 zrUjGoz9q)l516K@8NGb*iWqyH4qi8IIQs=g0%g~^M%#y=0znUyKEnaQ0`zD)Z_BXi zEZkNDFQpe2E{=7{tr|CSILH~Gn%W;lG4|GL8$twojj7C=E8AEoG09v`95z9lk=5*GYNl!`L{B z7Ck}AxIE>`c7VGRAhkeWncpvst)v|v-ikw+@wJu!ucazdKhRd}D51ZG;5liJDZtSY++H)~DF8i>= z?78`}k_cousO0R8mc6vX#{`$D_oZ{^&YVK9!Ot$tbNV~E>@8p13o~wP?|#LPO0s?3 zsZ7pB?1R=JgQg9*tQ`g7?LTF=SY;1sYRB#T*=%3ZdZ|hYl0mKUFw>sHm#PKtl%AdN z87p(XIGmb*#uV3ZJHEht=aW(Nbpnio_wBQq%)lOW-Gb-Ollj6m3Bu#m+s+br0HJHT zeuXZ4T+gq8RIZ@y`$CV0gv1oC8O#X+{^-EUnQV@3(=#Iu%Fxl#o$o=`GOsX7EzZy- zAW+9mC0v>BJ0B>o}D0%I$d*p$&HG9;cp zw?i~%z9HO5S#>75f%_A)!&p=MH~$UtuAktDDO!lU8ag6NxbSV-l7xg02`@9uuU)-5 zgHZnLVMcue^4gn!TqA)R+Gp94to*^uMJuX!D4#^G^J4oFhi4XwW_|2`si4hGw0AL5_tDw>*4 zJM~TpnPS72bEj4+4xvc7Nr8`SjmOL}Wb9a3!E?SWxd}UVTE|mS9E_@#bmPXN?|YV1 z4>4OLco&BFVPoW=UXC_`0Ci+}9u3g&(6WFc zFR$(K%n))T&()1rZ8pIcgl@{1lwrf%{IRY0R?*>$Zzp$ulU7#xZnS?7-`iXA`wur! 
zlP^$LzhQdaH*U9TSGBHc-MaPKJwvI>?kj62w^vdcdSG{*zSY8`2WRBYks5!i=HN@q z(U+Drws`iqvSm%enRh?h+C0c(bpAT6`h8NhmivpHPh*xGZLt_BH;!%o^AslWK4gOR zRZvh+dYX`wks`WywBN6buIm%8KoZb7W9S|!jw z`GEtyGb`@lP`yT6JTi=PeD+mooJ4ow!i95Cn=ckTt@teDi(^9`FN`u%YVoe$sPkk) zQuviCqu<(b2&!mu^IpE}BKg$vi$gnw>Wq9n*=M`Q$x4wID!~Wf+ECf0{VvILk00Lf zKEsri_c^YZY++#VherQDkqq}-|FY&=OCBey@Jh}+zwlga?5MD!{u`AhCmS`~L@>J& zm3c2k-=c(X+wc6HlcNU?`=h0$og`GplMs{AHqLNhsSt#re-zmtu`2OKqORBrv3$mr(?@OtOA(nD-cS^D=7G_y}Y(Mvb5QaRC!T{sBiaR_X_qBc_(sJdL=3h9Qcqm za!n}vTtBzJ`>o!U?gD-HAoikwn$rCzmg$8qlcWWxbl#)1>gVP!U%p_}k#W;1mhYzL z8RtK#^P>X^<^PPk8{pmy38|v>DnGx&o_k)=(P}$>{pcgfEh+JZZX~90>Z(`Af`d7u zSAj!3WL{OU5>js6;t-TW^b|%Ghk^rr{_bb6K=IkDSKZGn@BvG?SH6i{%MFw59UN>X zg+MU*bt!3!)f!fMGiTNXU@yAkwUP{G6}CeGFxEqPv0&?yen^lshrUS%6HaRr|*u)h?a+w-YI+m9~}7`0|a zncZJCUQ}@Owm>51FYEfg{%Alk94ROL+VfWfqKDJ3J@na9aZ+mU<@ZtImr57|NPbN5FK-azPEo0@B z1G_NmbX8uW6WDliQ}S}!Y9UHIXQp=K*s+e>HHJFWznZRgZD+XGw$zJ-F$KBy3Awdf z_(x!YK=W!R-*ECoNfJNnyqOv*$cX(J?ZZOCJ$vdn*xQc+uayLILtWP@Xt-fVd( zF5T3$J@?peb_Z@V4^g~(9>)7Gj7{1hd&suX6Ah?mS* zw5S($5guq49#LSHH{JPZO&7IQuR6AG_q?!B0Q-g!J^#AV&px65;+Q@FE1NXW9If!_ z6`(b&cyr}DB}Yd`utbYZ-Mc8Eu^pcL^mX>EofAhl%B}-fvc4#}*zBnq*sXi_gXQHm z>=4726ov9`&fn4eX=d||>Ef*9tXFl59#=PZ@p+(>^Oy5OX-hgRBer2FUjNO~nj;om zTI8yw$im&D^W8Vt)vATr;&W}j3V)loRa{#|k#xuQG z#>x#jV@y+POB!I<_3NfwQu(GS-j=ccuJ^`W${%stYU9Jur*o_)SUJC_a_SP0ITV7o z$Bm#y73W)J_djfv4tcD6Yg%DVO=NTP=WyA#?gQrg&F=MRR<|AjQ`H^@1pTDpZ7p%` zpVoQ%9g&*iXc*wz+BPn!`J=33x8u%n{>q&*Clva>))}h3Ev`uY^ny$5`MusQuJ`tz;Mg3S?kFt5s%HfS<1snHx15}m z-@ks%Tl!zG^uZo-|MfxJuCiaeZb(>~E!b&~F;fh{x_i$Pj8AHX?SDzzare`@|NHOh zkuD|u-=ws0_jUgh8E10nDj(oqzj_ZH8UG*8_5Z~$Zy5AI@@c~USw2TaX5U=>1$tL> HSML5_&&H5% literal 0 HcmV?d00001 diff --git a/course_UvA-DL/introduction-to-pytorch/continuous_xor.svg b/course_UvA-DL/introduction-to-pytorch/continuous_xor.svg new file mode 100644 index 000000000..12bfd7f81 --- /dev/null +++ b/course_UvA-DL/introduction-to-pytorch/continuous_xor.svg @@ -0,0 +1 @@ + diff --git a/course_UvA-DL/introduction-to-pytorch/pytorch_computation_graph.svg b/course_UvA-DL/introduction-to-pytorch/pytorch_computation_graph.svg new file mode 100644 index 000000000..19c488782 --- /dev/null +++ b/course_UvA-DL/introduction-to-pytorch/pytorch_computation_graph.svg @@ -0,0 +1 @@ +x2abc3yViewer does not support full SVG 1.1 diff --git a/course_UvA-DL/introduction-to-pytorch/small_neural_network.svg b/course_UvA-DL/introduction-to-pytorch/small_neural_network.svg new file mode 100644 index 000000000..065197aa0 --- /dev/null +++ b/course_UvA-DL/introduction-to-pytorch/small_neural_network.svg @@ -0,0 +1 @@ +x1x2Viewer does not support full SVG 1.1