# [2.3] - PPO (exercises)

> **ARENA [Streamlit Page](https://arena-chapter2-rl.streamlit.app/03_[2.3]_PPO)**
>
> **Colab: [exercises](https://colab.research.google.com/github/callummcdougall/ARENA_3.0/blob/main/chapter2_rl/exercises/part3_ppo/2.3_PPO_exercises.ipynb?t=20250330) | [solutions](https://colab.research.google.com/github/callummcdougall/ARENA_3.0/blob/main/chapter2_rl/exercises/part3_ppo/2.3_PPO_solutions.ipynb?t=20250330)**

Please send any problems / bugs on the `#errata` channel in the [Slack group](https://join.slack.com/t/arena-uk/shared_invite/zt-2zick19fl-6GY1yoGaoUozyM3wObwmnQ), and ask any questions on the dedicated channels for this chapter of material.

You can collapse each section so only the headers are visible, by clicking the arrow symbol on the left hand side of the markdown header cells.

Links to all other chapters: [(0) Fundamentals](https://arena-chapter0-fundamentals.streamlit.app/), [(1) Transformer Interpretability](https://arena-chapter1-transformer-interp.streamlit.app/), [(2) RL](https://arena-chapter2-rl.streamlit.app/).

<img src="https://raw.githubusercontent.com/info-arena/ARENA_img/main/misc/headers/header-23.png" width="350">

# Introduction

Proximal Policy Optimization (PPO) is a cutting-edge reinforcement learning algorithm that has gained significant attention in recent years. As an improvement over traditional policy optimization methods, PPO addresses key challenges such as sample efficiency, stability, and robustness in training deep neural networks for reinforcement learning tasks. With its ability to strike a balance between exploration and exploitation, PPO has demonstrated remarkable performance across a wide range of complex environments, including robotics, game playing, and autonomous control systems.

In this section, you'll build your own agent to perform PPO on the CartPole environment. By the end, you should be able to train your agent to near perfect performance in about 30 seconds. You'll also be able to try out other things like **reward shaping**, to make it easier for your agent to learn to balance, or to do fun tricks! There are also additional exercises which allow you to experiment with other tasks, including **Atari** and the 3D physics engine **MuJoCo**.

A lot of the setup as we go through these exercises will be similar to what we did yesterday for DQN, so you might find yourself moving quickly through certain sections.

## Content & Learning Objectives

### 0️⃣ Whirlwind Tour of PPO

In this non-exercise-based section, we discuss some of the mathematical intuitions underpinning PPO. It's not compulsory to go through all of it (and various recommended reading material / online lectures may provide better alternatives), although we strongly recommend everyone to at least read the summary boxes at the end of each subsection.

> ##### Learning Objectives
>
> - Understand the mathematical intuitions of PPO
> - Learn how expressions like the PPO objective function are derived

### 1️⃣ Setting up our agent

We'll start by building up most of our PPO infrastructure. Most importantly, this involves creating our actor & critic networks and writing methods using both of them which take steps in our environment. The result will be a `PPOAgent` and `ReplayMemory` class, analogous to our `DQNAgent` and `ReplayBuffer` from yesterday.

> ##### Learning Objectives
>
> - Understand the difference between the actor & critic networks, and what their roles are
> - Learn about & implement generalised advantage estimation
> - Build a replay memory to store & sample experiences
> - Design an agent class to step through the environment & record experiences

### 2️⃣ Learning Phase

The PPO objective function is considerably more complex than DQN and involves a lot of moving parts. In this section we'll go through each of those parts one by one, understanding its role and how to implement it.

> ##### Learning Objectives
>
> - Implement the total objective function (sum of three separate terms)
> - Understand the importance of each of these terms for the overall algorithm
> - Write a function to return an optimizer and learning rate scheduler for your model

### 3️⃣ Training Loop

Lastly, we'll assemble everything together into a `PPOTrainer` class just like our `DQNTrainer` class from yesterday, and use it to train on CartPole. We can also go further than yesterday by using **reward shaping** to fast-track our agent's learning trajectory.

> ##### Learning Objectives
>
> - Build a full training loop for the PPO algorithm
> - Train our agent, and visualise its performance with Weights & Biases media logger
> - Use reward shaping to improve your agent's training (and make it do tricks!)

### 4️⃣ Atari

Now that we've got training working on CartPole, we'll extend to the more complex environment of Atari. There are no massively new concepts in this section, although we do have to deal with a very different architecture that takes into account the visual structure of our observations (Atari frames). In particular, this will also require a shared architecture between the actor & critic networks.

<img src="https://raw.githubusercontent.com/info-arena/ARENA_img/main/misc/atari-demo.gif" width="200"><br>

> ##### Learning Objectives
>
> - Understand how PPO can be used in visual domains, with appropriate architectures (CNNs)
> - Understand the idea of policy and value heads
> - Train an agent to solve the Breakout environment

### 5️⃣ MuJoCo

The last new set of environments we'll look at is MuJoCo. This is a 3D physics engine, which you might be familiar with in the context of OpenAI's famous backflipping noodle which laid the background for RLHF (see tomorrow for more on this!). The most important concept MuJoCo introduces for us is the idea of a **continuous action space**, where actions aren't chosen from a finite set of options but are sampled from some probability distribution (in this case, a parameterized normal distribution). This is one setting that PPO can work in, but DQN can't.

<img src="https://images.ctfassets.net/kftzwdyauwt9/cf6fdf49-ea9e-489d-eb53eceeebc7/03dec4ea90925c03dea2ee6c4976e921/humanfeedbackjump.gif?w=2048&q=90&fm=webp" width="200"><br>

> ##### Learning Objectives
>
> - Understand how PPO can be used to train agents in continuous action spaces
> - Install and interact with the MuJoCo physics engine
> - Train an agent to solve the Hopper environment

### ☆ Bonus

We conclude with a set of optional bonus exercises, which you can try out before moving on to the RLHF sections.

## Notes on today's workflow

Your implementation might get good benchmark scores by the end of the day, but don't worry if it struggles to learn the simplest of tasks. RL can be frustrating because the feedback you get is extremely noisy: the agent can fail even with correct code, and succeed with buggy code. Forming a systematic process for coping with the confusion and uncertainty is the point of today, more so than producing a working PPO implementation.

Some parts of your process could include:

- Forming hypotheses about why it isn't working, and thinking about what tests you could write, or where you could set a breakpoint to confirm the hypothesis.
- Implementing some of the even more basic gymnasium environments and testing your agent on those.
- Getting a sense for the meaning of various logged metrics, and what this implies about the training process.
- Noticing confusion and sections that don't make sense, and investigating this instead of hand-waving over it.

## Readings

In section 0️⃣, we've included a whirlwind tour of PPO which is specifically tailored to today's exercises. Going through the entire thing isn't required (since it can get quite mathematically dense), but **we strongly recommend everyone at least read the summary boxes at the end of each subsection**. Many of the resources listed below are also useful, but they don't cover everything which is specifically relevant to today's exercises.

If you find this section sufficient then you can move on to the exercises. If not, other strongly recommended reading includes:

- [An introduction to Policy Gradient methods - Deep RL](https://www.youtube.com/watch?v=5P7I-xPq8u8) (20 mins)
    - This is a useful video which motivates the core setup of PPO (and in particular the clipped objective function) without spending too much time with the precise derivations. We recommend watching this video before doing the exercises.
    - Note - you can ignore the short section on multi-GPU setup.
    - Also, near the end the video says that PPO outputs parameters $\mu$ and $\sigma$ from which actions are sampled. This is true for continuous action spaces (which we'll be using later on), but we'll start by implementing PPO on CartPole, where the action space is discrete just like yesterday.
- [The 37 Implementation Details of Proximal Policy Optimization](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/#solving-pong-in-5-minutes-with-ppo--envpool)
    - This is not required reading before the exercises, but **it will be a useful reference point as you go through the exercises** (and it's also a useful thing to take away from the course as a whole, since your future work in RL will likely be less guided than these exercises).
    - The good news is that you won't need all 37 of these today, so no need to read to the end.
    - We will be tackling the 13 "core" details, not in the same order as presented here. Some of the sections below are labelled with the number they correspond to in this page (e.g. **Minibatch Update ([detail #6](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/#:~:text=Mini%2Dbatch%20Updates))**).
- [Proximal Policy Optimization Algorithms](https://arxiv.org/pdf/1707.06347.pdf)
    - **This is not required reading before the exercises**, but it will be a useful reference point for many of the key equations as you go through the exercises. In particular, you will find up to page 5 useful.


### Optional Reading

- [Spinning Up in Deep RL - Vanilla Policy Gradient](https://spinningup.openai.com/en/latest/algorithms/vpg.html#background)
    - PPO is a fancier version of vanilla policy gradient, so if you're struggling to understand PPO it may help to look at the simpler setting first.
- [Spinning Up in Deep RL - PPO](https://spinningup.openai.com/en/latest/algorithms/ppo.html)
    - You don't need to follow all the derivations here, although as a general idea by the end you should at least have a qualitative understanding of what all the symbols represent.
- [Andy Jones - Debugging RL, Without the Agonizing Pain](https://andyljones.com/posts/rl-debugging.html)
    - You've already read this previously but it will come in handy again.
    - You'll want to reuse your probe environments from yesterday, or you can import them from the solution if you didn't implement them all.
- [Tricks from Deep RL Bootcamp at UC Berkeley](https://github.com/williamFalcon/DeepRLHacks/blob/master/README.md)
    - This contains more debugging tips that may be of use.
- [Lilian Weng Blog on PPO](https://lilianweng.github.io/posts/2018-04-08-policy-gradient/#ppo)
    - Her writing on ML topics is consistently informative and informationally dense.

## Setup code

```python
import os
import sys
from pathlib import Path

IN_COLAB = "google.colab" in sys.modules

chapter = "chapter2_rl"
repo = "ARENA_3.0"
branch = "main"

# Install dependencies
try:
    import jaxtyping
except:
    %pip install wandb==0.18.7 einops gymnasium[atari,accept-rom-license,other,mujoco-py]==0.29.0 pygame jaxtyping

# Get root directory, handling 3 different cases: (1) Colab, (2) notebook not in ARENA repo, (3) notebook in ARENA repo
root = (
    "/content"
    if IN_COLAB
    else "/root"
    if repo not in os.getcwd()
    else str(next(p for p in Path.cwd().parents if p.name == repo))
)

if Path(root).exists() and not Path(f"{root}/{chapter}").exists():
    if not IN_COLAB:
        !sudo apt-get install unzip
        %pip install jupyter ipython --upgrade

    if not os.path.exists(f"{root}/{chapter}"):
        !wget -P {root} https://github.com/callummcdougall/ARENA_3.0/archive/refs/heads/{branch}.zip
        !unzip {root}/{branch}.zip '{repo}-{branch}/{chapter}/exercises/*' -d {root}
        !mv {root}/{repo}-{branch}/{chapter} {root}/{chapter}
        !rm {root}/{branch}.zip
        !rmdir {root}/{repo}-{branch}


if f"{root}/{chapter}/exercises" not in sys.path:
    sys.path.append(f"{root}/{chapter}/exercises")

os.chdir(f"{root}/{chapter}/exercises")
```

```python
import itertools
import os
import sys
import time
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Literal

import einops
import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
import torch as t
import torch.nn as nn
import torch.optim as optim
import wandb
from IPython.display import HTML, display
from jaxtyping import Bool, Float, Int
from matplotlib.animation import FuncAnimation
from numpy.random import Generator
from torch import Tensor
from torch.distributions.categorical import Categorical
from torch.optim.optimizer import Optimizer
from tqdm import tqdm

warnings.filterwarnings("ignore")

# Make sure exercises are in the path
chapter = "chapter2_rl"
section = "part3_ppo"
root_dir = next(p for p in Path.cwd().parents if (p / chapter).exists())
exercises_dir = root_dir / chapter / "exercises"
section_dir = exercises_dir / section
if str(exercises_dir) not in sys.path:
    sys.path.append(str(exercises_dir))

import part3_ppo.tests as tests
from part1_intro_to_rl.utils import set_global_seeds
from part2_q_learning_and_dqn.solutions import Probe1, Probe2, Probe3, Probe4, Probe5, get_episode_data_from_infos
from part2_q_learning_and_dqn.utils import prepare_atari_env
from part3_ppo.utils import arg_help, make_env
from plotly_utils import plot_cartpole_obs_and_dones

# Register our probes from last time
for idx, probe in enumerate([Probe1, Probe2, Probe3, Probe4, Probe5]):
    gym.envs.registration.register(id=f"Probe{idx+1}-v0", entry_point=probe)

Arr = np.ndarray

device = t.device("mps" if t.backends.mps.is_available() else "cuda" if t.cuda.is_available() else "cpu")
```

# 0️⃣ Whirlwind tour of PPO

> ##### Learning Objectives
>
> - Understand the mathematical intuitions of PPO
> - Learn how expressions like the PPO objective function are derived

This section is quite mathematically dense, and you'll cover a lot of it again as you go through the exercises (plus, understanding all the details behind PPO isn't strictly necessary to get all the benefits from this chapter). 

At the end of each subsection we've included a box which summarizes the main points covered so far which should help distill out the main ideas as you're reading, as well as a couple of questions which help you check your understanding. We strongly recommend reading at least the contents of these boxes and attempting the questions, and also reading the section at the end describing today's setup.

Also, an important note - to simplify notation, everything here assumes a finite-horizon setting and no discount factor i.e. $\gamma = 1$. If we remove these assumptions, not a whole lot changes, but we wanted to highlight it to avoid confusion.

## Policy Gradient methods vs DQN

We'll start by discussing the general class of **policy gradient methods** (of which PPO is a member), and compare it to DQN. To recap, DQN works as follows:

> DQN involved learning the Q-function $Q(s, a)$, which represents the expected time-discounted future reward of taking action $a$ in state $s$. The update steps were based on the Bellman equation - this equation is only satisfied if we've found the true Q-function, so we minimize the squared TD residual to find it. We can derive the optimal policy by argmaxing $Q(s, a)$ over actions $a$.

On the other hand, policy gradient methods take a more direct route - we write the expected future reward $J(\pi_\theta)$ as a function of our policy $\pi_\theta(a_t \mid s_t)$ (which takes the form of a neural network mapping from states to action logits), and then perform gradient ascent on this function to improve our policy directly i.e. $\theta \leftarrow \theta + \alpha \nabla_\theta J(\pi_\theta)$. In this way, we essentially sidestep having to think about the Bellman equation at all, and we directly optimize our policy function.

A question remains here - how can we take the derivative of expected future returns as a function of our policy $\pi_\theta(a_t \mid s_t)$, in a way which is differentiable wrt $\theta$? We'll discuss this in the next section.

> #### Summary so far
> 
> - In **policy gradient methods**, we directly optimize the policy function $\pi_\theta(a_t \mid s_t)$ to get higher expected future rewards.

<!-- 
|                     | DQN                                                                                               | PPO                                                                                                                                                                                                                                                                                                                                                                                                          |
|---------------------|--------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| **What do we learn?** | We learn the Q-function $Q(s, a)$.                                                         | We learn the policy function $\pi(a \mid s)$.                                                                                                                                                                                                                                                                                                                     |
| **Where do our actions come from?** | Argmaxing $Q(s, a)$ over actions $a$ gives us a deterministic policy. We combine this with an epsilon-greedy algorithm when sampling actions, to enable exploration. | We directly learn our stochastic policy $\pi$, and we can sample actions from it.                                                                                                                                                                                                                                           |
| **What networks do we have?** | Our network `q_network` takes $s$ as inputs and outputs the Q-values for each possible action $a$. We also had a `target_network`, although this was just a lagged version of `q_network` rather than one that actually gets trained. | We have two networks: `actor`, which learns the policy function, and `critic`, which learns the value function $V(s)$. These two work in tandem: the `actor` requires the `critic`'s output to estimate the policy gradient and perform gradient ascent, and the `critic` tries to learn the value function of the `actor`'s current policy.                        |
| **Where do our gradients come from?** | We do gradient descent on the squared TD residual, i.e. the residual of the Bellman equation (which is only satisfied if we've found the true Q-function). | For our `actor`, we do gradient ascent on an estimate of the time-discounted future reward stream (i.e. we're directly moving up the **policy gradient**; changing our policy in a way that leads to higher expected reward). Our `critic` trains by minimizing the TD residual.                                                                                   |
| **Techniques to improve stability?** | We use a "lagged copy" of our network to sample actions from; in this way, we don't update too fast after only having seen a small number of possible states. In the DQN code, this was `q_network` and `target_network`. | We use a "lagged copy" of our policy in mathematical notation, this is $\theta$ and $\theta_{old}$. In the code, we won't actually need to make a different network for this. We clip the objective function to make sure large policy changes aren't incentivized past a certain point (this is the "proximal" part of PPO).                                     |
| **Suitable for continuous action spaces?** | No. Our Q-function $Q$ is implemented as a network which takes in states and returns Q-values for each discrete action. It's not even good for large action spaces! | Yes. Our policy function $\pi$ can take continuous argument $a$.                                                                                                                                                                                                                                                       |
-->

## Policy Gradient objective function

> Note, this is the most mathematically dense part of the material, and probably the part that's most fine to skip.

Let $\pi_\theta$ be our policy parameterized by $\theta$, and let $J(\pi_\theta)$ denote the expected return of the policy (assuming undiscounted & finite-horizon for simplicity):
$$
J(\pi_\theta) = \underset{\tau \sim \pi_\theta}{\mathbb{E}} \left[ \sum_{t=0}^T r_{t+1}(s_t, a_t, s_{t+1}) \right]
$$

where $\tau = (s_0, a_0, ..., s_T)$ stands for a trajectory sampled from policy $\pi_\theta$, and we've written the rewards as $r_{t+1}(s_t, a_t, s_{t+1})$ to make it clear that they are a function of the trajectory. Then a theorem sometimes called the **policy gradient theorem** (PGT) says that the gradient of $J(\pi_\theta)$ is:
$$
\nabla_\theta J\left(\pi_\theta\right)=\underset{\tau \sim \pi_\theta}{\mathbb{E}}\left[\sum_{t=0}^T \nabla_\theta \log \pi_\theta\left(a_t \mid s_t\right) A_\theta(s_t, a_t)\right] \quad (*)
$$

where $A_\theta(s_t, a_t)$ is the **advantage function**, defined as $Q_\theta(s_t, a_t) - V_\theta(s_t)$, i.e. how much better it is to choose action $a_t$ in state $s_t$ as compared to the value obtained by following $\pi_\theta$ from that state onwards.

The derivation is optional (included in a dropdown below for completeness), but it is worth thinking about this expression intuitively:

- If the advantage $A_\theta(s_t, a_t)$ is positive, this means taking action $a_t$ in state $s_t$ is better on average than what we'd do under $\pi_\theta$. So if we increase $\pi_\theta(a_t \mid s_t)$ then $J(\pi_\theta)$ will increase, and vice-versa.
- If the advantage $A_\theta(s_t, a_t)$ is negative, this means taking action $a_t$ in state $s_t$ is worse than expectation, so if we increase $\pi_\theta(a_t \mid s_t)$ then $J(\pi_\theta)$ will decrease, and vice-versa.

So this expression is telling us that **we should change our policy $\pi_\theta(a_t \mid s_t)$ in a way which makes advantageous actions more likely, and non-advantageous actions less likely.** All pretty intuitive!

Note that instead of advantages $A_\theta(s_t, a_t)$, we could just use the total reward $R(\tau)$ in our objective function (we show this in the derivation below, because proving $(*)$ with $R(\tau)$ instead is actually an intermediate step in the proof). We call this algorithm REINFORCE. It can work, but the problem is it leads to much higher variance in our assessment of the value of actions. As a thought experiment: imagine you're assessing moves in a game of chess which you ended up winning. With $R(\tau)$, every move contributing to the win is rewarded equally, and so we can't differentiate between excellent moves and average ones. The advantage function is a lot more fine-grained: it allows you to answer the question "was this move better than what I typically play in this position?" for each move, isolating the contribution of the move from the game outcome as a whole.
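
To make the variance claim concrete, here is a minimal sketch on a hypothetical three-armed bandit (a single-step environment, so the advantage is just the reward minus the policy's expected reward). The arm rewards and logits are made-up numbers, and the policy is assumed to be a softmax over logits. Rather than sampling, it enumerates the actions to get the exact mean and variance of the single-sample gradient estimator, with and without the baseline:

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def grad_log_prob(probs, action):
    # For a softmax policy, d/d(logit_j) log pi(action) = 1[j == action] - pi_j
    return [(1.0 if j == action else 0.0) - p for j, p in enumerate(probs)]


def estimator_mean_and_variance(probs, rewards, baseline):
    """Exact mean and total variance of grad log pi(a) * (reward(a) - baseline), with a ~ pi."""
    n = len(probs)
    samples = [[g * (rewards[a] - baseline) for g in grad_log_prob(probs, a)] for a in range(n)]
    mean = [sum(probs[a] * samples[a][j] for a in range(n)) for j in range(n)]
    second_moment = sum(probs[a] * sum(x * x for x in samples[a]) for a in range(n))
    variance = second_moment - sum(m * m for m in mean)
    return mean, variance


probs = softmax([0.0, 0.5, -0.5])  # hypothetical policy logits
rewards = [10.0, 11.0, 12.0]  # hypothetical reward for each arm
value = sum(p * r for p, r in zip(probs, rewards))  # expected reward under the policy

mean_raw, var_raw = estimator_mean_and_variance(probs, rewards, baseline=0.0)
mean_adv, var_adv = estimator_mean_and_variance(probs, rewards, baseline=value)

print("same expected gradient:", all(abs(x - y) < 1e-9 for x, y in zip(mean_raw, mean_adv)))
print(f"variance with raw reward: {var_raw:.3f}, with advantage: {var_adv:.3f}")
```

The expected gradient is identical in both cases, but subtracting the baseline removes the large offset shared by all the rewards, which is where most of the variance came from.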

<details>
<summary>Full derivation (optional)</summary>

> Summary: we can quite easily get a formula for $\nabla_\theta J\left(\pi_\theta\right)$ which looks like the one above, but has the total reward $R(\tau)$ instead of the advantage $A_\theta(s_t, a_t)$. We can then use a bag of tricks to show that we can subtract baseline functions from $R(\tau)$ in our expression without changing its value, until we get to an expression with $A_\theta(s_t, a_t)$ instead.

Let's go through the derivation line by line. Denoting $\int_\tau$ as an integral over the distribution of all states & actions in the trajectory, we have:

$$
\begin{aligned}
\nabla_\theta J\left(\pi_\theta\right) & =\nabla_\theta \underset{\tau \sim \pi_\theta}{\mathbb{E}}[R(\tau)] \\
& =\nabla_\theta \int_\tau P(\tau \mid \theta) R(\tau) \quad \text{Expand integration} \\
& =\int_\tau \nabla_\theta P(\tau \mid \theta) R(\tau) \quad \text{Bring gradient inside integral} \\
& =\int_\tau P(\tau \mid \theta) \nabla_\theta \log P(\tau \mid \theta) R(\tau) \quad \text{Use log derivative trick} \\
& =\underset{\tau \sim \pi_\theta}{\mathbb{E}}\left[\nabla_\theta \log P(\tau \mid \theta) R(\tau)\right] \quad \text{Recognize expectation} \\
& =\underset{\tau \sim \pi_\theta}{\mathbb{E}}\left[\sum_{t=0}^T \nabla_\theta \log \pi_\theta\left(a_t \mid s_t\right) R(\tau)\right] \quad (*)
\end{aligned}
$$

Where the log derivative trick is a rearrangement of the standard result $\nabla_\theta \log P_\theta(x) = \frac{\nabla_\theta P_\theta(x)}{P_\theta(x)}$, and the final line was reached by writing $P(\tau \mid \theta)$ as a product of transition probabilities and policy probabilities $\pi_\theta(a_t \mid s_t)$, using the fact that log of a product is the sum of logs, meaning we get:
$$
\log P(\tau \mid \theta) = \sum_{t=0}^T \log \pi_\theta(a_t \mid s_t) + \sum_{t=0}^T \log P(s_{t+1} \mid s_t, a_t)
$$
and the latter term vanishes when we take the gradient wrt $\theta$ because it's independent of $\theta$.

This formula looks like the one we had above, the only difference is that we have the total trajectory reward $R(\tau)$ instead of the advantage $A_\theta(s_t, a_t)$. And intuitively it seems like we should be able to get to this - after all, $R(\tau)$ is the sum of rewards at each timestep, and rewards accumulated before time $t$ shouldn't affect whether $\pi_\theta(a_t \mid s_t)$ should be increased or decreased. Although this is very intuitive, it turns out we have to do a bit of work to prove it.

Firstly, let's establish a lemma called the **expected grad-log-prob lemma**. This states that the expected gradient of the log-probability under a distribution is always zero, when the expectation is taken over samples from that same distribution:

$$
\underset{x \sim P_\theta}{\mathbb{E}}\left[\nabla_\theta \log P_\theta(x)\right] = 0
$$

Proof: we can write the expectation above as the integral $\int_x P_\theta(x) \nabla_\theta \log P_\theta(x) dx$, which can be written as $\int_x \nabla_\theta P_\theta(x) dx$ by the log-derivative trick. Then we swap the order of integration and differentiation to get $\nabla_\theta \int_x P_\theta(x) dx$, and then using the fact that $P_\theta$ is a probability distribution, this becomes $\nabla_\theta 1 = 0$.
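
As a quick numerical sanity check of the lemma, here is a small sketch. It assumes a softmax distribution over four made-up logits, and takes gradients by finite differences rather than autograd so that nothing is hidden:

```python
import math


def log_softmax(logits):
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def numerical_grad_log_prob(logits, action, eps=1e-6):
    # Central finite difference of log P(action) with respect to each logit
    grad = []
    for j in range(len(logits)):
        up = list(logits)
        up[j] += eps
        down = list(logits)
        down[j] -= eps
        grad.append((log_softmax(up)[action] - log_softmax(down)[action]) / (2 * eps))
    return grad


logits = [0.3, -1.2, 2.0, 0.5]  # hypothetical parameters theta
probs = [math.exp(lp) for lp in log_softmax(logits)]

# E_{x ~ P}[grad log P(x)], computed exactly by summing over all outcomes
expected_grad = [0.0] * len(logits)
for action, p in enumerate(probs):
    for j, g in enumerate(numerical_grad_log_prob(logits, action)):
        expected_grad[j] += p * g

print(expected_grad)  # every entry should be zero up to finite-difference error
```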

To return to our policy gradient setting, not only does this show us that $\mathbb{E}_{\tau \sim \pi_\theta}\left[\nabla_\theta \log \pi_\theta(a_t \mid s_t)\right] = 0$, it also shows us that $\mathbb{E}_{\tau \sim \pi_\theta}\left[\nabla_\theta \log \pi_\theta(a_t \mid s_t) f(\tau) \right] = 0$ whenever the function $f(\tau)$ is only a function of the early trajectory values $s_0, a_0, ..., s_{t-1}, a_{t-1}, s_t$, because this still falls out as zero when we integrate over the distribution of our action $a_t$. This means that in $(*)$, we can actually replace $R(\tau)$ with $R(\tau) - f(\tau)$ for any such choice of $f(\tau)$. We choose $f(\tau) = \mathbb{E}_{\tau \sim \pi_\theta}\left[R(\tau) \mid s_0, ..., s_t\right]$, i.e. the expected return conditioned on the trajectory up to and including state $s_t$. Using the law of total expectation, we can also replace $R(\tau)$ with its expectation conditioned on $s_0, ..., s_t, a_t$. The already-accumulated rewards $r_1, ..., r_t$ then cancel between the two terms, and so in $(*)$ the term for any given timestep $t$ becomes:

$$
\underset{s_0, ..., s_t, a_t}{\mathbb{E}}\left[\nabla_\theta \log \pi_\theta\left(a_t \mid s_t\right) \left( \underset{\tau \sim \pi_\theta}{\mathbb{E}}\left[R(\tau) \mid s_0, ..., s_t, a_t\right] - \underset{\tau \sim \pi_\theta}{\mathbb{E}}\left[R(\tau) \mid s_0, ..., s_t\right] \right) \right]
$$

but since the first of these terms conditions on the action $a_t$ and the second doesn't, we recognize the term in the large brackets as exactly $Q_\theta(s_t, a_t) - V_\theta(s_t) = A_\theta(s_t, a_t)$, as required.

</details>

We have this formula, but how do we use it to get an objective function we can optimize for? The answer is that we take estimates of the advantage function $\hat{A}_{\theta_\text{target}}(s_t, a_t)$ using a frozen version of our parameters $\theta_{\text{target}}$ (like we took next-step Q-values from a frozen target network in DQN), and use our non-frozen parameters to get our values $\pi_\theta(a_t \mid s_t)$. For a given batch of experiences $B$ (which can be assumed to be randomly sampled across various different trajectories $\tau$), our objective function is:
$$
L(\theta) = \frac{1}{|B|} \sum_{t \in B} \log \pi_\theta(a_t \mid s_t) \hat{A}_{\theta_\text{target}}(s_t, a_t) 
$$
because then:
$$
\nabla_\theta L(\theta) = \frac{1}{|B|} \sum_{t \in B} \nabla_\theta \log \pi_\theta(a_t \mid s_t) \hat{A}_{\theta_\text{target}}(s_t, a_t) \approx \nabla_\theta J(\pi_\theta)
$$
exactly as we want! We can now perform gradient ascent on this objective function to improve our policy: $\theta \leftarrow \theta + \alpha \nabla_\theta L(\theta)$ will be an approximation of the ideal update rule $\theta \leftarrow \theta + \alpha \nabla_\theta J(\pi_\theta)$.
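
A minimal PyTorch sketch of this objective is below. The `actor` here is a hypothetical single linear layer, and the observations, actions and advantages are random placeholders rather than real rollout data. The point is only to show which quantities track gradients and which don't.

```python
import torch as t
from torch.distributions.categorical import Categorical

t.manual_seed(0)
batch_size, obs_dim, num_actions = 8, 4, 2

# Hypothetical stand-in for the actor: one linear layer from observations to action logits
actor = t.nn.Linear(obs_dim, num_actions)

# Hypothetical batch of experiences (random data, just to make the shapes concrete)
obs = t.randn(batch_size, obs_dim)
actions = t.randint(0, num_actions, (batch_size,))
advantages = t.randn(batch_size)  # stands in for the frozen advantage estimates, no gradient

logits = actor(obs)
logprobs = Categorical(logits=logits).log_prob(actions)  # log pi_theta(a_t | s_t), tracks gradients
objective = (logprobs * advantages).mean()  # L(theta), averaged over the batch

# Optimizers minimize, so gradient ascent on L(theta) is gradient descent on -L(theta)
(-objective).backward()
print(objective.item(), actor.weight.grad.shape)
```

Because `advantages` carries no gradient, the backward pass only flows through the logprobs, which matches the expression for $\nabla_\theta L(\theta)$ above.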

> #### Summary so far
> 
> - In **policy gradient methods**, we directly optimize the policy function $\pi_\theta(a_t \mid s_t)$ to get higher expected future rewards.
> - Our objective function is a sum of logprobs of actions taken, weighted by their advantage estimates $\hat{A}_\theta(s_t, a_t)$ (i.e. how good we think that action was), so performing gradient ascent on this leads to taking more good actions and fewer bad ones.
>   - Note that we could just use accumulated rewards $R(\tau)$ rather than the advantage function, but using advantage estimates is a lot more stable.

<details>
<summary>Question - can you intuitively explain how the advantage function influences policy updates?</summary>

The advantage function scales updates; positive $A_\theta$ will cause us to increase the action likelihood (because the log probability of that action will have a positive coefficient in the objective function), and negative $A_\theta$ will cause us to decrease the action likelihood.

</details>

## Actor & critic

Unlike DQN, we require 2 different networks for policy gradient methods:

- `actor`: learns the policy function $\pi_\theta(a_t \mid s_t)$, i.e. inputs are $s_t$ and outputs (for discrete action spaces) are a vector of logits for each action $a_t$
- `critic`: learns the value function $V_\theta(s_t)$ which is used to compute the advantage estimates $\hat{A}_\theta(s_t, a_t)$, i.e. inputs are $s_t$ and the output is a single scalar value
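
As a rough sketch of the two networks' input and output shapes (assuming CartPole's 4-dimensional observations and 2 actions; the hidden size and activation are arbitrary choices for illustration, not necessarily the architecture used in the exercises):

```python
import torch as t
import torch.nn as nn

# CartPole has 4 observation dimensions and 2 actions; the hidden size is an arbitrary choice
obs_dim, num_actions, hidden = 4, 2, 64

actor = nn.Sequential(nn.Linear(obs_dim, hidden), nn.Tanh(), nn.Linear(hidden, num_actions))
critic = nn.Sequential(nn.Linear(obs_dim, hidden), nn.Tanh(), nn.Linear(hidden, 1))

obs = t.randn(5, obs_dim)  # a batch of 5 hypothetical observations
logits = actor(obs)  # shape (5, num_actions): one logit per action
values = critic(obs).squeeze(-1)  # shape (5,): one scalar value estimate per state
print(logits.shape, values.shape)
```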

As we discussed in the last section, estimating $\hat{A}_\theta(s_t, a_t)$ is valuable because without it we'd have to rely on the accumulated reward $R(\tau)$ in our objective function, which is very high-variance and doesn't allow for granular credit assignment to different actions. In simple environments like CartPole you may find you can get by without the critic, but as we move into more complex environments this ceases to be the case.

Note - unlike DQN, **policy gradient methods are able to take continuous action spaces**. This is because we can have our `actor` output a vector of means and variances parameterising a distribution over actions, and then sample actions from this distribution. On the other hand, our Q-network in DQN is only able to take in states and output a single scalar value for each discrete action. This will be important when we look at MuJoCo later.
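
As a rough illustration of the continuous case, the sketch below has an actor output the mean of a Gaussian over actions and sample from it. The class name `GaussianActor`, the layer sizes, and the choice of a state-independent learned log standard deviation are all assumptions made for this example.

```python
import torch as t
import torch.nn as nn


class GaussianActor(nn.Module):  # hypothetical class, for illustration only
    def __init__(self, num_obs: int, num_actions: int):
        super().__init__()
        self.mean_net = nn.Sequential(nn.Linear(num_obs, 64), nn.Tanh(), nn.Linear(64, num_actions))
        # Assumption: one learned log-std per action dimension, independent of the state
        self.log_std = nn.Parameter(t.zeros(num_actions))

    def forward(self, obs: t.Tensor) -> t.distributions.Normal:
        mean = self.mean_net(obs)
        return t.distributions.Normal(mean, self.log_std.exp())


actor = GaussianActor(num_obs=3, num_actions=2)
dist = actor(t.randn(5, 3))  # batch of 5 observations
actions = dist.sample()  # continuous actions, shape (5, 2)
logprobs = dist.log_prob(actions).sum(-1)  # joint logprob over action dimensions, shape (5,)
```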

You might have a question at this point - **how does the critic learn the value function**? After all, the loss function $L(\theta)$ is designed just to update the policy $\pi_\theta$ (i.e. the actor network), not the critic network. The critic is used to compute the advantage function estimates, but these come from $\theta_\text{old}$ in the objective function $L(\theta)$, i.e. they don't track gradients. The answer is that we improve our value function estimates by adding another loss term which **minimizes the TD residual**. We'll add the term $(V_\theta(s_t) - V_t^\text{target})^2$ to our loss function, where $V_\theta(s_t)$ are the value function estimates from our critic network (which do track gradients) and $V_t^\text{target} := V_{\theta_\text{target}}(s_t) + \hat{A}_{\theta_\text{target}}(s_t, a_t)$ are the next-step value estimates taken from our target network, which take into account the actual action taken $a_t$ and how much it changed our value.
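
A minimal sketch of this value loss term is given below. The tensors are toy stand-ins (in practice `values_new` would be the critic's output), and averaging over the batch rather than summing is an assumption.

```python
import torch as t

# Quantities stored during rollout (computed with the frozen parameters, no gradients)
values_old = t.tensor([1.0, 2.0, 3.0])  # V_{theta_target}(s_t)
advantages = t.tensor([0.5, -0.5, 0.0])  # advantage estimates
v_target = values_old + advantages  # V_t^target

# Stand-in for the current critic's predictions V_theta(s_t), which do track gradients
values_new = t.tensor([1.0, 2.0, 3.0], requires_grad=True)

value_loss = (values_new - v_target).pow(2).mean()
value_loss.backward()  # only the critic's predictions receive gradients
```

Minimizing this loss moves each prediction towards its target: the first prediction is below its target, so its gradient is negative and gradient descent increases it.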

> #### Summary so far
> 
> - In **policy gradient methods**, we directly optimize the policy function $\pi_\theta(a_t \mid s_t)$ to get higher expected future rewards.
> - Our objective function is a sum of logprobs of actions taken, weighted by their advantage estimates $\hat{A}_\theta(s_t, a_t)$ (i.e. how good we think that action was), so performing gradient ascent on this leads to taking more good actions and less bad ones.
>   - Note that we could just use accumulated rewards $R(\tau)$ rather than the advantage function, but using advantage estimates is a lot more stable.
> - We have 2 networks: `actor` which learns $\pi_\theta(a_t \mid s_t)$ using this objective function, and `critic` which learns $V_\theta(s_t)$ by minimizing the TD residual (a bit like SARSA), and allows us to estimate the advantage $\hat{A}_\theta(s_t, a_t)$ which is used in the objective function.

<details>
<summary>Question - why do policy gradient methods require both actor and critic networks, and how do they complement each other?</summary>

The actor learns the policy; the critic estimates value functions for stable advantage calculation. Without the actor we wouldn't have any policy to learn the value for, and without the critic we wouldn't be able to competently estimate the advantage function which is necessary so that we can compute our objective function / understand how we should update our policy.

</details>

<details>
<summary>Question - why is the critic network's loss function conceptually similar to the update rule we used for SARSA?</summary>

The SARSA update rule was:

$$
Q(s_t,a_t) \leftarrow Q(s_t,a_t) + \eta \left( r_{t+1} + \gamma Q(s_{t+1}, a_{t+1}) - Q(s_t,a_t) \right)
$$

which is actually equivalent to the update rule we'd get if our loss function was the squared TD error to our $Q$ function. Our critic loss function is pretty much the same idea, except we're applying the TD error to $V(s_t)$ rather than $Q(s_t, a_t)$.

Note that SARSA differed from Q-Learning/DQN because the latter also included a maximization over the action space - we were essentially doing policy improvement and learning the value function for it at the same time. Here, our critic loss function is more conceptually similar to SARSA than it is to Q-Learning/DQN, because the policy improvement is coming from the actor network instead.


</details>

## Generalized Advantage Estimation

We've got a lot of the pieces in place for PPO now - we have an actor and critic network, and we have 2 objective functions: one to train the critic to estimate the value function $V_\theta(s_t)$ accurately (these estimates are used to compute the advantage estimates $\hat{A}_\theta(s_t, a_t)$), and one which trains the actor to maximize the expected future reward based on these advantage estimates.

A question remains now - how do we use value estimates to compute advantage estimates? Here are some ideas:

1. We can use the 1-step residual, i.e. $\hat{A}_\theta(s_t, a_t) = \delta_t = r_t + \gamma V_\theta(s_{t+1}) - V_\theta(s_t)$, just like we used in DQN. The problem with this is that we're only estimating the advantage based on a single action taken, which is a bit too myopic. If we sacrifice a piece in our chess game to win the match, we want to make sure our advantage estimates take this future position into account, rather than just thinking that we're down one piece!
2. We can use the sum of future residuals, i.e. $\hat{A}_\theta(s_t, a_t) = \delta_t + \delta_{t+1} + ...$. This fixes the myopia problem, but brings back an old problem - doing this is pretty much just like using $R(\tau)$ in our objective function instead, in other words we're looking at the entire future trajectory at once! This leads to unstable training, and an inability to credit any individual action.

The solution is a middle ground between these two: we perform **generalized advantage estimation** (GAE), which is a sum of future residuals but geometrically decaying according to some factor $\lambda$. In other words, we take $\hat{A}^{\text{GAE}(\lambda)}_\theta(s_t, a_t) = \delta_t + \lambda \delta_{t+1} + \lambda^2 \delta_{t+2} + ...$ (we've left out the discount factor $\gamma$ here for simplicity; in the implementation each factor of $\lambda$ becomes $\gamma \lambda$). This is effectively the best of both worlds - we put the largest weight on the next actions taken (allowing us to attribute actions rather than entire trajectories), but also we do take into account future states in our trajectory (meaning we're not only concerned with the immediate reward). Note that $\lambda=0$ reduces to the first idea, and $\lambda=1$ to the second.

Note that the fact that we use GAE also helps a lot for our critic network - in SARSA we were minimizing the 1-step TD error, but here we're training $V(s_t)$ to be more in line with a lookahead estimate of the value function which takes into account many future actions and states. This helps improve stability and speed up convergence.
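
The two extremes and the middle ground can be checked on a toy trajectory. This is a sketch under simplifying assumptions (no discount factor, no early termination), and `gae_no_discount` is a hypothetical helper rather than the function you'll implement later.

```python
def gae_no_discount(deltas, lam):
    """A_t = delta_t + lam * delta_{t+1} + lam^2 * delta_{t+2} + ... (no gamma, no termination)."""
    return [sum(lam**k * d for k, d in enumerate(deltas[i:])) for i in range(len(deltas))]


deltas = [-1.0, 0.0, 3.0]  # e.g. a sacrifice now, with the payoff two steps later

myopic = gae_no_discount(deltas, lam=0.0)  # only the 1-step residual
full = gae_no_discount(deltas, lam=1.0)  # sum of all future residuals
middle = gae_no_discount(deltas, lam=0.5)  # geometrically decaying sum

print(myopic)  # [-1.0, 0.0, 3.0]
print(full)  # [2.0, 3.0, 3.0]
print(middle)  # [-0.25, 1.5, 3.0]
```

With $\lambda=0$ the sacrifice at $t=0$ looks purely bad, with $\lambda=1$ it gets full credit for the later payoff, and intermediate values of $\lambda$ interpolate between the two.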

> #### Summary so far
> 
> - In **policy gradient methods**, we directly optimize the policy function $\pi_\theta(a_t \mid s_t)$ to get higher expected future rewards.
> - Our objective function is a sum of logprobs of actions taken, weighted by their advantage estimates $\hat{A}_\theta(s_t, a_t)$ (i.e. how good we think that action was), so performing gradient ascent on this leads to taking more good actions and less bad ones.
>   - Note that we could just use accumulated rewards $R(\tau)$ rather than the advantage function, but using advantage estimates is a lot more stable.
> - We have 2 networks: `actor` which learns $\pi_\theta(a_t \mid s_t)$ using this objective function, and `critic` which learns $V_\theta(s_t)$ by minimizing the TD residual (a bit like SARSA), and allows us to estimate the advantage $\hat{A}_\theta(s_t, a_t)$ which is used in the objective function.
> - We use **generalized advantage estimation** (GAE) to convert our value function estimates into advantage estimates - this mostly avoids the two possible extremes of (1) myopia from only looking at the immediate reward, and (2) instability / failure to credit single actions from looking at the entire trajectory at once.

<details>
<summary>Question - can you explain why using GAE is better than using the realized return trajectory in our loss function?</summary>

GAE is much more stable, because using the entire trajectory means we're only taking into account the actual reward accumulated (which can have much higher variance than an advantage estimate, assuming we already have a good policy). Additionally, GAE allows us to credit individual actions for the future rewards they lead to, which is something we couldn't do with the realized return trajectory.

</details>

## PPO

We're pretty much there now - we've established all the theoretical pieces we need for PPO, and there's just 3 final things we need to add to the picture.

1. **We use an entropy bonus to encourage policy exploration, and prevent premature convergence to suboptimal policies.**

Entropy is a very deep mathematical topic that we won't dive all the way into here, but for the sake of brevity, we can say that entropy is a measure of uncertainty - policies which will definitely take the same action in the same state have low entropy, and policies which have a wide range of likely actions have high entropy. We add some multiple of the entropy of our policy function directly onto our objective function to be maximized, with the entropy coefficient usually decaying over time as we move from explore to exploit mode.
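
For a discrete policy, the entropy can be read straight off the action distribution. The short sketch below compares a uniform and a nearly deterministic policy over two actions; the probabilities are made up for illustration.

```python
import torch as t
from torch.distributions import Categorical

uniform = Categorical(probs=t.tensor([0.5, 0.5]))  # maximally uncertain over 2 actions
peaked = Categorical(probs=t.tensor([0.99, 0.01]))  # almost always takes the first action

h_uniform = uniform.entropy()  # equals ln(2), the maximum for 2 actions
h_peaked = peaked.entropy()  # much closer to zero
```

Adding a positive multiple of this entropy to the objective therefore rewards the policy for keeping several actions plausible.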

2. **We clip the objective function $L(\theta)$ to get $L^\text{CLIP}(\theta)$, to prevent the policy from changing too much too fast.**

The clipping is applied to make sure the ratio $\pi_\theta(a_t \mid s_t) / \pi_{\theta_\text{target}}(a_t \mid s_t)$ stays close to 1, during any single learning phase (between learning phases we generate a new batch of experiences and reset $\theta_\text{target}$). Intuitively, this is because the more our policy changes from the old policy, the more unrealistic the generated experiences will be. For example, suppose we generate experiences from a bunch of chess games, where a particular class of strategies (e.g. playing more aggressively) is beneficial. We shouldn't update on these games indefinitely, because as we update and the agent's policy changes to become more aggressive, the generated experiences will no longer be accurate representations of the agent's strategy and so our objective function will no longer be a good estimate of the expected future reward.

There are various ways to keep this ratio close to 1. **Trust region policy optimization** (TRPO) explicitly constrains the [KL divergence](https://www.lesswrong.com/posts/no5jDTut5Byjqb4j5/six-and-a-half-intuitions-for-kl-divergence) between the old and new policies, making sure the distributions stay close. PPO however does something much simpler and hackier - if the ratio is larger than $1+\epsilon$ for actions with positive advantage (for some constant $\epsilon > 0$) then we clip it, which stops gradients from flowing and pushing the policy further in that direction. We do the same thing with $1-\epsilon$ when the advantage is negative.
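
One common way to write this clipping is as the minimum of a clipped and an unclipped term, sketched below. The function name `clipped_surrogate` and the toy inputs are assumptions for illustration; `clip_coef` plays the role of $\epsilon$.

```python
import torch as t


def clipped_surrogate(logprobs_new, logprobs_old, advantages, clip_coef=0.2):
    """Mean of min(r * A, clip(r, 1 - eps, 1 + eps) * A), where r = pi_new / pi_old."""
    ratio = (logprobs_new - logprobs_old).exp()
    unclipped = ratio * advantages
    clipped = ratio.clamp(1 - clip_coef, 1 + clip_coef) * advantages
    return t.minimum(unclipped, clipped).mean()


logprobs_old = t.log(t.tensor([0.5, 0.5]))
logprobs_new = t.log(t.tensor([0.75, 0.25])).requires_grad_()  # ratios are 1.5 and 0.5
advantages = t.tensor([1.0, -1.0])

objective = clipped_surrogate(logprobs_new, logprobs_old, advantages)
objective.backward()
```

Here both ratios have already moved past the clipping range in the direction their advantage favors, so the clipped branch is selected for both samples and `logprobs_new.grad` is zero: these samples no longer push the policy further.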

3. **We actually use $\dfrac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_\text{target}}(a_t \mid s_t)}$ rather than $\log \pi_\theta(a_t \mid s_t)$ when computing $L^\text{CLIP}(\theta)$.**

This is a valid thing to do precisely because we're using clipping - this allows us to assume the probability ratio is usually close to 1, and so we can use the approximation $\log(x) \approx x - 1$ for all $x \approx 1$ (with this approximation, the two loss functions are equal up to a constant that doesn't depend on $\theta$ - we leave the proof as an exercise to the reader).
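
For reference, here is one way to sketch that argument. Write $r_t(\theta)$ for the probability ratio (this shorthand is introduced here, not in the text above):

$$
\begin{aligned}
r_t(\theta) &:= \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_\text{target}}(a_t \mid s_t)} \\
\log \pi_\theta(a_t \mid s_t) &= \log \pi_{\theta_\text{target}}(a_t \mid s_t) + \log r_t(\theta) \approx \log \pi_{\theta_\text{target}}(a_t \mid s_t) + r_t(\theta) - 1 \\
\hat{A}_t \log \pi_\theta(a_t \mid s_t) &\approx \hat{A}_t \, r_t(\theta) + \underbrace{\hat{A}_t \left( \log \pi_{\theta_\text{target}}(a_t \mid s_t) - 1 \right)}_{\text{does not depend on } \theta}
\end{aligned}
$$

so, as long as $r_t(\theta) \approx 1$, both objectives have approximately the same gradient with respect to $\theta$.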

Although these 3 changes are all important, it's the 2nd alteration that distinguishes PPO from other policy gradient methods. It's where it gets its name - **proximal** refers to the way in which this clipping keeps us close to the old policy.

> #### Full summary
> 
> - In **policy gradient methods**, we directly optimize the policy function $\pi_\theta(a_t \mid s_t)$ to get higher expected future rewards.
> - Our objective function is a sum of logprobs of actions taken, weighted by their advantage estimates $\hat{A}_\theta(s_t, a_t)$ (i.e. how good we think that action was), so performing gradient ascent on this leads to taking more good actions and less bad ones.
>   - Note that we could just use accumulated rewards $R(\tau)$ rather than the advantage function, but using advantage estimates is a lot more stable.
> - We have 2 networks: `actor` which learns $\pi_\theta(a_t \mid s_t)$ using this objective function, and `critic` which learns $V_\theta(s_t)$ by minimizing the TD residual (a bit like SARSA), and allows us to estimate the advantage $\hat{A}_\theta(s_t, a_t)$ which is used in the objective function.
> - We use **generalized advantage estimation** (GAE) to convert our value function estimates into advantage estimates - this mostly avoids the two possible extremes of (1) myopia from only looking at the immediate reward, and (2) instability / failure to credit single actions from looking at the entire trajectory at once.
> - On top of all this, 2 other techniques fully characterize PPO:
>   - We add an **entropy bonus** to our objective function to encourage exploration.
>   - We clip the objective function so $\pi_\theta(a_t \mid s_t)$ isn't incentivized to change too quickly (which could cause instability) - this is the "proximal" part of PPO.

<details>
<summary>Question - what is the role of the entropy term, and should it be added to or subtracted from the clipped objective function?</summary>

The entropy term encourages policy exploration. We want to add it to the objective function, because we're doing gradient ascent on it (and we want to increase exploration).

</details>

<details>
<summary>Question - how does clipping the objective function help prevent large policy updates, and why is this desirable?</summary>

Clipping prevents large policy changes by zeroing the gradient for any sample whose probability ratio has already moved more than $\epsilon$ away from 1 in the direction favored by its advantage, which generally keeps the new policy close to the old one.

This is good because we don't want to change our policy too quickly, based on possibly a limited set of experiences.

</details>

## Today's setup

We've now covered all the theory we need to understand about PPO! To conclude, we'll briefly go through how our PPO algorithm is going to be set up in practice today, and relate it to what we've discussed in the previous sections.

A full diagram of our implementation is shown below:

<img src="https://raw.githubusercontent.com/info-arena/ARENA_img/main/misc/ppo-alg-conceptual.png" width="900">

We have 2 main phases: the **rollout phase** (where we generate a batch of experiences from our frozen network $\theta_\text{target}$) and the **learning phase** (where we update our policy $\pi_\theta$ based on the experiences generated in the rollout phase, as well as the outputs of our current network $\theta$). This is quite similar to the setup we used for DQN (where we'd alternate between generating experiences for our buffer and learning from those experiences) - the difference here is that rather than the rollout phase adding to our buffer, it'll be emptying our buffer and creating an entirely new one. So as not to be wasteful, the learning phase will iterate over our buffer multiple times before we repeat the cycle.

Just like we had `ReplayBuffer`, `DQNAgent` and `DQNTrainer` as our 3 main classes yesterday, here we have 3 main classes:

- `ReplayMemory` stores experiences generated by the agent during rollout, and has a `get_minibatches` method which samples data to be used in the learning phase
- `PPOAgent` manages the interaction between our policy and environment (particularly via the `play_step` method), it generates experiences and adds them to our memory
- `PPOTrainer` is the main class for training our model, it's essentially a wrapper around everything else

As we can see in the diagram, the learning phase has us compute an objective function which involves 3 different terms:

- The **entropy bonus** (for encouraging policy exploration) - this trains only our actor network
- The **clipped surrogate objective function** (for policy improvement) - this trains only our actor network (although it uses the critic network's estimates from $\theta_\text{target}$), and it's the most important of the three terms
- The **value function loss** (for improving our value function estimates) - this trains only our critic network
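
The three terms are typically combined into a single scalar with fixed coefficients. The sketch below assumes the usual sign convention (maximize the policy term and entropy, minimize the value loss) and uses the default `vf_coef` and `ent_coef` values from the `PPOArgs` dataclass in this section; the input numbers are made up.

```python
import torch as t


def total_objective(clipped_surrogate, value_loss, entropy, vf_coef=0.25, ent_coef=0.01):
    """Quantity to maximize: policy term, minus weighted value loss, plus weighted entropy bonus."""
    return clipped_surrogate - vf_coef * value_loss + ent_coef * entropy


obj = total_objective(
    clipped_surrogate=t.tensor(0.2),
    value_loss=t.tensor(0.4),
    entropy=t.tensor(0.69),
)
# 0.2 - 0.25 * 0.4 + 0.01 * 0.69 = 0.1069
```

Since most optimizers minimize, an implementation along these lines would call `backward()` on the negative of this quantity.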

# 1️⃣ Setting up our agent

> ##### Learning Objectives
>
> - Understand the difference between the actor & critic networks, and what their roles are
> - Learn about & implement generalised advantage estimation
> - Build a replay memory to store & sample experiences
> - Design an agent class to step through the environment & record experiences

In this section, we'll do the following:

* Define a dataclass to hold our PPO arguments
* Write functions to create our actor and critic networks (which will eventually be stored in our `PPOAgent` instance)
* Write a function to do **generalized advantage estimation** (this will be necessary when computing our objective function during the learning phase)
* Fill in our `ReplayMemory` class (for storing and sampling experiences)
* Fill in our `PPOAgent` class (a wrapper around our networks and our replay memory, which will turn them into an agent)

As a reminder, we'll be continually referring back to [The 37 Implementation Details of Proximal Policy Optimization](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/#solving-pong-in-5-minutes-with-ppo--envpool) as we go through these exercises. Most of our sections will refer to one or more of these details.

## PPO Arguments

Just like for DQN, we've provided you with a dataclass containing arguments for your `train_ppo` function. We've also given you a function from `utils` to display all these arguments (including which ones you've changed). Lots of these are the same as for the DQN dataclass.

Don't worry if these don't all make sense right now, they will by the end.

```python
@dataclass
class PPOArgs:
    # Basic / global
    seed: int = 1
    env_id: str = "CartPole-v1"
    mode: Literal["classic-control", "atari", "mujoco"] = "classic-control"

    # Wandb / logging
    use_wandb: bool = False
    video_log_freq: int | None = None
    wandb_project_name: str = "PPOCartPole"
    wandb_entity: str | None = None

    # Duration of different phases
    total_timesteps: int = 500_000
    num_envs: int = 4
    num_steps_per_rollout: int = 128
    num_minibatches: int = 4
    batches_per_learning_phase: int = 4

    # Optimization hyperparameters
    lr: float = 2.5e-4
    max_grad_norm: float = 0.5

    # RL hyperparameters
    gamma: float = 0.99

    # PPO-specific hyperparameters
    gae_lambda: float = 0.95
    clip_coef: float = 0.2
    ent_coef: float = 0.01
    vf_coef: float = 0.25

    def __post_init__(self):
        self.batch_size = self.num_steps_per_rollout * self.num_envs

        assert self.batch_size % self.num_minibatches == 0, "batch_size must be divisible by num_minibatches"
        self.minibatch_size = self.batch_size // self.num_minibatches
        self.total_phases = self.total_timesteps // self.batch_size
        self.total_training_steps = self.total_phases * self.batches_per_learning_phase * self.num_minibatches

        self.video_save_path = section_dir / "videos"


args = PPOArgs(num_minibatches=2)  # changing this also changes minibatch_size and total_training_steps
arg_help(args)
```

| arg | default value | description |
|---|---|---|
| seed | 1 | seed of the experiment |
| env_id | 'CartPole-v1' | the id of the environment |
| mode | 'classic-control' | can be 'classic-control', 'atari' or 'mujoco' |
| use_wandb | False | if toggled, this experiment will be tracked with Weights and Biases |
| video_log_freq | None | if not None, we log videos this many episodes apart (so shorter episodes mean more frequent logging) |
| wandb_project_name | 'PPOCartPole' | the name of this experiment (also used as the wandb project name) |
| wandb_entity | None | the entity (team) of wandb's project |
| total_timesteps | 500000 | total timesteps of the experiments |
| num_envs | 4 | number of synchronized vector environments in our `envs` object (this is N in the '37 Implementational Details' post) |
| num_steps_per_rollout | 128 | number of steps taken in the rollout phase (this is M in the '37 Implementational Details' post) |
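
To see how the derived quantities in `__post_init__` work out, here is the same arithmetic as a standalone sketch, using the default values together with `num_minibatches=2` as in the call above.

```python
num_steps_per_rollout, num_envs = 128, 4
num_minibatches, batches_per_learning_phase = 2, 4
total_timesteps = 500_000

batch_size = num_steps_per_rollout * num_envs  # experiences collected per rollout phase
minibatch_size = batch_size // num_minibatches  # experiences per gradient step
total_phases = total_timesteps // batch_size  # number of (rollout, learning) cycles
total_training_steps = total_phases * batches_per_learning_phase * num_minibatches

print(batch_size, minibatch_size, total_phases, total_training_steps)  # 512 256 976 7808
```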


A note on the `num_envs` argument - unlike yesterday, `envs` will actually have multiple instances of the environment inside (we did still have this argument yesterday but it was always set to 1). From the [37 implementation details of PPO](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/#:~:text=vectorized%20architecture) post:

> _In this architecture, PPO first initializes a vectorized environment `envs` that runs $N$ (usually independent) environments either sequentially or in parallel by leveraging multi-processes. `envs` presents a synchronous interface that always outputs a batch of $N$ observations from $N$ environments, and it takes a batch of $N$ actions to step the $N$ environments. When calling `next_obs = envs.reset()`, next_obs gets a batch of $N$ initial observations (pronounced "next observation"). PPO also initializes an environment `done` flag variable next_done (pronounced "next done") to an $N$-length array of zeros, where its i-th element `next_done[i]` has values of 0 or 1 which corresponds to the $i$-th sub-environment being *not done* and *done*, respectively._

## Actor-Critic Implementation ([detail #2](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/#:~:text=Orthogonal%20Initialization%20of%20Weights%20and%20Constant%20Initialization%20of%20biases))

PPO requires two networks, an `actor` and a `critic`. The actor is the most important one; its job is to learn an optimal policy $\pi_\theta(a_t \mid s_t)$ (it does this by training on the clipped surrogate objective function, which is essentially a direct estimation of the discounted sum of future rewards with some extra bells and whistles thrown in). Estimating this also requires estimating the **advantage function** $A_\theta(s_t, a_t)$, which in turn requires estimating the values $V_\theta(s_t)$ - this is why we need a critic network, which learns $V_\theta(s_t)$ by minimizing the TD residual (in a similar way to how our Q-network learned the $Q(s_t, a_t)$ values).

### Exercise - implement `get_actor_and_critic`

> ```yaml
> Difficulty: 🔴🔴⚪⚪⚪
> Importance: 🔵🔵🔵⚪⚪
> 
> You should spend up to 10-20 minutes on this exercise.
> ```

You should implement `get_actor_and_critic_classic` according to the diagram. We use separate actor and critic networks because [detail #13](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/#:~:text=Shared%20and%20separate%20MLP%20networks%20for%20policy%20and%20value%20functions) notes that this performs better than a single shared network in simple environments. Use `layer_init` to initialize each `Linear`, overriding the standard deviation argument `std` according to the diagram (when not specified, you should be using the default value of `std=np.sqrt(2)`).

<img src="https://raw.githubusercontent.com/info-arena/ARENA_img/main/misc/ppo_mermaid.svg" width="500">

We've given you a "high level function" `get_actor_and_critic` which calls one of three possible functions, depending on the `mode` argument. You'll implement the other two modes later. This is one way to keep our code modular.

```python
def layer_init(layer: nn.Linear, std=np.sqrt(2), bias_const=0.0):
    t.nn.init.orthogonal_(layer.weight, std)  # orthogonal weights, scaled by the gain `std`
    t.nn.init.constant_(layer.bias, bias_const)  # every bias set to the same constant
    return layer


def get_actor_and_critic(
    envs: gym.vector.SyncVectorEnv,
    mode: Literal["classic-control", "atari", "mujoco"] = "classic-control",
) -> tuple[nn.Module, nn.Module]:
    """
    Returns (actor, critic), the networks used for PPO, in one of 3 different modes.
    """
    assert mode in ["classic-control", "atari", "mujoco"]

    obs_shape = envs.single_observation_space.shape
    num_obs = np.array(obs_shape).prod()
    num_actions = (
        envs.single_action_space.n
        if isinstance(envs.single_action_space, gym.spaces.Discrete)
        else np.array(envs.single_action_space.shape).prod()
    )

    if mode == "classic-control":
        actor, critic = get_actor_and_critic_classic(num_obs, num_actions)
    if mode == "atari":
        actor, critic = get_actor_and_critic_atari(obs_shape, num_actions)  # you'll implement these later
    if mode == "mujoco":
        actor, critic = get_actor_and_critic_mujoco(num_obs, num_actions)  # you'll implement these later

    return actor.to(device), critic.to(device)


def get_actor_and_critic_classic(num_obs: int, num_actions: int):
    """
    Returns (actor, critic) in the "classic-control" case, according to diagram above.
    """
    # Actor: outputs one logit per action; small final std keeps the initial policy near-uniform
    actor = nn.Sequential(
        layer_init(nn.Linear(num_obs, 64)),
        nn.Tanh(),
        layer_init(nn.Linear(64, 64)),
        nn.Tanh(),
        layer_init(nn.Linear(64, num_actions), std=0.01),
    )

    # Critic: outputs a single scalar value estimate
    critic = nn.Sequential(
        layer_init(nn.Linear(num_obs, 64)),
        nn.Tanh(),
        layer_init(nn.Linear(64, 64)),
        nn.Tanh(),
        layer_init(nn.Linear(64, 1), std=1.0),
    )
    return actor, critic


tests.test_get_actor_and_critic(get_actor_and_critic, mode="classic-control")
```

All tests in `test_get_actor_and_critic(mode='classic-control')` passed!


<details>
<summary>Question - what do you think is the benefit of using a small standard deviation for the last actor layer?</summary>

The purpose is to center the initial `agent.actor` logits around zero, in other words an approximately uniform distribution over all actions independent of the state. If you didn't do this, then your agent might get locked into a nearly-deterministic policy early on and find it difficult to train away from it.

[Studies suggest](https://openreview.net/pdf?id=nIAxjsniDzg) this is one of the more important initialisation details, and performance is often harmed without it.
</details>
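
The effect of the small final-layer `std` can be checked numerically. The sketch below reuses `layer_init` from above and feeds random tanh activations (a made-up stand-in for the last hidden layer) through a final layer initialised with `std=0.01` versus the default `std=np.sqrt(2)`.

```python
import numpy as np
import torch as t
import torch.nn as nn


def layer_init(layer: nn.Linear, std=np.sqrt(2), bias_const=0.0):
    t.nn.init.orthogonal_(layer.weight, std)
    t.nn.init.constant_(layer.bias, bias_const)
    return layer


t.manual_seed(0)
hidden = t.tanh(t.randn(1000, 64))  # stand-in for the last hidden activations

with t.no_grad():
    probs_small = layer_init(nn.Linear(64, 2), std=0.01)(hidden).softmax(-1)
    probs_large = layer_init(nn.Linear(64, 2), std=np.sqrt(2))(hidden).softmax(-1)

# Largest deviation from the uniform policy (0.5, 0.5) over all 1000 inputs
dev_small = (probs_small - 0.5).abs().max().item()
dev_large = (probs_large - 0.5).abs().max().item()
```

With `std=0.01` every action probability stays within a few percent of 0.5, whereas the default initialisation already produces strongly skewed policies for some inputs.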


<details><summary>Solution</summary>

```python
def layer_init(layer: nn.Linear, std=np.sqrt(2), bias_const=0.0):
    t.nn.init.orthogonal_(layer.weight, std)
    t.nn.init.constant_(layer.bias, bias_const)
    return layer


def get_actor_and_critic(
    envs: gym.vector.SyncVectorEnv,
    mode: Literal["classic-control", "atari", "mujoco"] = "classic-control",
) -> tuple[nn.Module, nn.Module]:
    """
    Returns (actor, critic), the networks used for PPO, in one of 3 different modes.
    """
    assert mode in ["classic-control", "atari", "mujoco"]

    obs_shape = envs.single_observation_space.shape
    num_obs = np.array(obs_shape).prod()
    num_actions = (
        envs.single_action_space.n
        if isinstance(envs.single_action_space, gym.spaces.Discrete)
        else np.array(envs.single_action_space.shape).prod()
    )

    if mode == "classic-control":
        actor, critic = get_actor_and_critic_classic(num_obs, num_actions)
    if mode == "atari":
        actor, critic = get_actor_and_critic_atari(obs_shape, num_actions)  # you'll implement these later
    if mode == "mujoco":
        actor, critic = get_actor_and_critic_mujoco(num_obs, num_actions)  # you'll implement these later

    return actor.to(device), critic.to(device)


def get_actor_and_critic_classic(num_obs: int, num_actions: int):
    """
    Returns (actor, critic) in the "classic-control" case, according to diagram above.
    """
    critic = nn.Sequential(
        layer_init(nn.Linear(num_obs, 64)),
        nn.Tanh(),
        layer_init(nn.Linear(64, 64)),
        nn.Tanh(),
        layer_init(nn.Linear(64, 1), std=1.0),
    )

    actor = nn.Sequential(
        layer_init(nn.Linear(num_obs, 64)),
        nn.Tanh(),
        layer_init(nn.Linear(64, 64)),
        nn.Tanh(),
        layer_init(nn.Linear(64, num_actions), std=0.01),
    )
    return actor, critic


tests.test_get_actor_and_critic(get_actor_and_critic, mode="classic-control")
```
</details>

## Generalized Advantage Estimation ([detail #5](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/#:~:text=Generalized%20Advantage%20Estimation))

The **advantage function** $A_\theta(s_t, a_t)$ is defined as $Q_\theta(s_t, a_t) - V_\theta(s_t)$, i.e. it's the difference between expected future reward when we take action $a_t$ vs taking the expected action according to policy $\pi_\theta$ from that point onwards. It's an important part of our objective function, because it tells us whether we should try and take more or less of action $a_t$ in state $s_t$.

Our critic estimates $V_\theta(s_t)$, but how can we estimate the terms $Q_\theta(s_t, a_t)$ and by extension $A_\theta(s_t, a_t)$? Here are two ideas:

1. $\hat{A}_\theta(s_t, a_t) = \delta_t := r_t + \gamma V_\theta(s_{t+1}) - V_\theta(s_t)$, i.e. the 1-step residual just like we used in DQN
2. $\hat{A}_\theta(s_t, a_t) = \delta_t + \delta_{t+1} + ...$, i.e. the sum of all future residuals in subsequent states & actions in the trajectory

The problem with (1) is it's too myopic, because we're only looking 1 move ahead. If we sacrifice a chess piece to win the game 3 moves later, we want that to actually have positive advantage! (2) fixes this problem, but creates a new problem - the long time horizon estimates are very unstable, and so much of the trajectory is taken into account that it's hard to attribute any individual action. The solution is somewhere between these two - we perform **generalized advantage estimation** (GAE), which is a geometrically decaying sum of future residuals. It's controlled by the parameter $\lambda$, where $\lambda=0$ reduces to the first idea (single-step, myopic), and $\lambda=1$ to the second (full trajectories, unstable). This way, we can balance these two, and get the best of both worlds.

$$
\hat{A}^{\text{GAE}(\lambda)}_t = \delta_t + (\gamma \lambda) \delta_{t+1} + \cdots + (\gamma \lambda)^{T-1-t} \delta_{T-1}
$$

Note a subtlety - we need to make sure $\delta_t$ is correctly defined as $r_t - V_\theta(s_t)$ in terminating states, i.e. we don't include the $V_\theta(s_{t+1})$ term. We can actually compute the GAE estimator (taking this into account) with the following recursive formula:
$$
\hat{A}^{\text{GAE}(\lambda)}_t = \delta_t + (1 - d_{t+1}) (\gamma \lambda) \hat{A}^{\text{GAE}(\lambda)}_{t+1}
$$

<details>
<summary>Derivation (short)</summary>

If $d_{t+1}=1$ (i.e. we just terminated) then we'll get no further rewards, so our advantage on the final step of this trajectory is just $A_t = r_t - V(s_t) = \delta_t$ (since $V(s_t)$ was the estimate for future rewards before termination, and $r_t$ is the reward we actually get before terminating). So the formula above is correct for the terminal step.

Working backwards from the terminal step and applying this recursive formula, we get:

$$
\begin{aligned}
\hat{A}^{\text{GAE}(\lambda)}_{t} &= \delta_{t} \\
\hat{A}^{\text{GAE}(\lambda)}_{t-1} &= \delta_{t-1} + (\gamma \lambda) \hat{A}^{\text{GAE}(\lambda)}_{t} = \delta_{t-1} + \gamma \lambda \delta_t \\
\hat{A}^{\text{GAE}(\lambda)}_{t-2} &= \delta_{t-2} + (\gamma \lambda) \hat{A}^{\text{GAE}(\lambda)}_{t-1} = \delta_{t-2} + \gamma \lambda \left(\delta_{t-1} + (\gamma\lambda) \delta_t\right) = \delta_{t-2} + \gamma \lambda \delta_{t-1} + (\gamma\lambda)^2 \delta_t \\
&\dots
\end{aligned}
$$
and so on. This exactly matches the formula given above.

</details>
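
As a quick sanity check of this derivation, the sketch below compares the explicit sum with the backwards recursion on a single toy trajectory that terminates after its last step. Both helper functions are hypothetical and operate on plain Python lists rather than the batched tensors used in the exercise.

```python
def gae_explicit(deltas, gamma, lam):
    """A_t = sum over k of (gamma * lam)^k * delta_{t+k}, up to the end of the trajectory."""
    T = len(deltas)
    return [sum((gamma * lam) ** k * deltas[i + k] for k in range(T - i)) for i in range(T)]


def gae_recursive(deltas, gamma, lam):
    """A_t = delta_t + gamma * lam * A_{t+1}, working backwards from the terminal step."""
    advantages = [0.0] * len(deltas)
    running = 0.0  # advantage after the terminal step is zero
    for i in reversed(range(len(deltas))):
        running = deltas[i] + gamma * lam * running
        advantages[i] = running
    return advantages


deltas = [1.0, -2.0, 0.5, 3.0]
explicit = gae_explicit(deltas, gamma=0.99, lam=0.95)
recursive = gae_recursive(deltas, gamma=0.99, lam=0.95)
```

The two lists agree up to floating point error, and the last entry is just the final residual.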

### Exercise - implement `compute_advantages`

> ```yaml
> Difficulty: 🔴🔴🔴🔴⚪
> Importance: 🔵🔵🔵⚪⚪
> 
> You should spend up to 20-30 minutes on this exercise.
> ```

Below, you should fill in `compute_advantages`. We recommend using a reversed for loop over $t$ to get it working, and using the recursive formula for GAE given above - don't worry about trying to vectorize it.

Tip - make sure you understand what the indices are of the tensors you've been given! The tensors `rewards`, `values` and `terminated` contain $r_t$, $V(s_t)$ and $d_t$ respectively for all $t = 0, 1, ..., T-1$, and `next_value`, `next_terminated` are the values $V(s_T)$ and $d_T$ respectively (required for the calculation of the very last advantage $A_{T-1}$).

In [23]:
@t.inference_mode()
def compute_advantages(
    next_value: Float[Tensor, "num_envs"],
    next_terminated: Bool[Tensor, "num_envs"],
    rewards: Float[Tensor, "buffer_size num_envs"],
    values: Float[Tensor, "buffer_size num_envs"],
    terminated: Bool[Tensor, "buffer_size num_envs"],
    gamma: float,
    gae_lambda: float,
) -> Float[Tensor, "buffer_size num_envs"]:
    """
    Compute advantages using Generalized Advantage Estimation.
    """

    T = values.shape[0]
    terminated = terminated.float()
    next_terminated = next_terminated.float()

    # Get tensors of V(s_{t+1}) and d_{t+1} for all t = 0, 1, ..., T-1
    next_values = t.concat([values[1:], next_value[None, :]])
    next_terminated = t.concat([terminated[1:], next_terminated[None, :]])

    # Compute deltas: \delta_t = r_t + (1 - d_{t+1}) \gamma V(s_{t+1}) - V(s_t)
    deltas = rewards + gamma * next_values * (1.0 - next_terminated) - values

    # Compute advantages using the recursive formula, starting with advantages[T-1] = deltas[T-1] and working backwards
    advantages = t.zeros_like(deltas)
    advantages[-1] = deltas[-1]
    for s in reversed(range(T - 1)):
        advantages[s] = deltas[s] + gamma * gae_lambda * (1.0 - terminated[s + 1]) * advantages[s + 1]

    return advantages


tests.test_compute_advantages(compute_advantages)

Testing with all dones=False, single environment ... 
Testing with all dones=False, multiple environments ... 
Testing with episode termination, single environment ... 
Testing with episode termination, multiple environments ... 
All tests in `test_compute_advantages_single` passed!


<details>
<summary>Help - I get <code>RuntimeError: Subtraction, the `-` operator, with a bool tensor is not supported</code></summary>

This is probably because you're trying to perform an operation on a boolean tensor `terminated` or `next_terminated` which was designed for floats. You can fix this by casting the boolean tensor to a float tensor.

</details>


<details><summary>Solution</summary>

```python
@t.inference_mode()
def compute_advantages(
    next_value: Float[Tensor, "num_envs"],
    next_terminated: Bool[Tensor, "num_envs"],
    rewards: Float[Tensor, "buffer_size num_envs"],
    values: Float[Tensor, "buffer_size num_envs"],
    terminated: Bool[Tensor, "buffer_size num_envs"],
    gamma: float,
    gae_lambda: float,
) -> Float[Tensor, "buffer_size num_envs"]:
    """
    Compute advantages using Generalized Advantage Estimation.
    """
    T = values.shape[0]
    terminated = terminated.float()
    next_terminated = next_terminated.float()

    # Get tensors of V(s_{t+1}) and d_{t+1} for all t = 0, 1, ..., T-1
    next_values = t.concat([values[1:], next_value[None, :]])
    next_terminated = t.concat([terminated[1:], next_terminated[None, :]])

    # Compute deltas: \delta_t = r_t + (1 - d_{t+1}) \gamma V(s_{t+1}) - V(s_t)
    deltas = rewards + gamma * next_values * (1.0 - next_terminated) - values

    # Compute advantages using the recursive formula, starting with advantages[T-1] = deltas[T-1] and working backwards
    advantages = t.zeros_like(deltas)
    advantages[-1] = deltas[-1]
    for s in reversed(range(T - 1)):
        advantages[s] = deltas[s] + gamma * gae_lambda * (1.0 - terminated[s + 1]) * advantages[s + 1]

    return advantages
```
</details>

## Replay Memory

Our replay memory has some similarities to the replay buffer from yesterday, as well as some important differences.

### Sampling method

Yesterday, we continually updated our buffer and sliced off old data, and each time we called `sample` we'd take a randomly ordered subset of that data (with replacement).

With PPO, we alternate between rollout and learning phases. In rollout, we fill our replay memory entirely. In learning, we call `get_minibatches` to return the entire contents of the replay memory, but randomly shuffled and sorted into minibatches. In this way, we update on every experience, not just random samples. In fact, we'll update on each experience more than once, since we'll repeat the process of (generate minibatches, update on all of them) `batches_per_learning_phase` times during each learning phase.
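
The difference between the two sampling schemes can be sketched with NumPy index arrays. This is an illustration with made-up sizes rather than the actual buffer code, and the variable names `dqn_sample` and `ppo_minibatches` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, minibatch_size = 12, 4  # made-up sizes for illustration

# DQN-style: draw indices with replacement, so repeats and omissions are both possible
dqn_sample = rng.integers(0, batch_size, size=minibatch_size)

# PPO-style: shuffle every index once, then cut the permutation into minibatches
ppo_minibatches = rng.permutation(batch_size).reshape(-1, minibatch_size)

# Every experience appears exactly once across the PPO minibatches
assert sorted(ppo_minibatches.flatten().tolist()) == list(range(batch_size))
```

Repeating the PPO-style shuffle `batches_per_learning_phase` times gives each experience that many updates per learning phase.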

### New variables

We store some of the same variables as before - $(s_t, a_t, d_t)$, but with the addition of 3 new variables: the **logprobs** $\log \pi(a_t\mid s_t)$, the **advantages** $A_t$ and the **returns**. Explaining these three variables and why we need them:

- `logprobs` are calculated from the logit outputs of our actor network (`agent.actor`), corresponding to the actions $a_t$ which our agent actually chose.
    - These are necessary for calculating the clipped surrogate objective (see equation $(7)$ on page 3 of the [PPO paper](https://arxiv.org/pdf/1707.06347.pdf)), which as we'll see later makes sure the agent isn't rewarded for changing its policy an excessive amount.
- `advantages` are the terms $\hat{A}_t$, computed using our function `compute_advantages` from earlier.
    - Again, these are used in the calculation of the clipped surrogate objective.
- `returns` are given by the formula `returns = advantages + values` - see [detail #9](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/#:~:text=Value%20Function%20Loss%20Clipping).
    - They are used to train the value network, in a way which is equivalent to minimizing the TD residual loss used in DQN.

Don't worry if you don't understand all of this now, we'll get to all these variables later.
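
As a sanity check on the `returns = advantages + values` formula, the sketch below looks at the special case `gae_lambda = 0`, where the GAE recursion collapses to $\hat{A}_t = \delta_t$. In that case the returns are exactly the one-step TD targets $r_t + \gamma (1 - d_{t+1}) V(s_{t+1})$. All numbers are made up, and the example assumes a single environment.

```python
import numpy as np

gamma = 0.99
rewards = np.array([1.0, 1.0, 1.0])
values = np.array([2.0, 1.5, 0.5])           # V(s_0), V(s_1), V(s_2)
next_values = np.array([1.5, 0.5, 0.0])      # V(s_1), V(s_2), V(s_3)
next_terminated = np.array([0.0, 0.0, 1.0])  # the episode ends after the last step

# With gae_lambda = 0, each advantage is just the TD error delta_t
deltas = rewards + gamma * (1.0 - next_terminated) * next_values - values
advantages = deltas
returns = advantages + values

# The returns then coincide with the one-step TD targets
td_targets = rewards + gamma * (1.0 - next_terminated) * next_values
assert np.allclose(returns, td_targets)
```

For `gae_lambda > 0` the returns blend in longer-horizon information, but they play the same role as a regression target for the value network.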

### Exercise - implement `get_minibatch_indices`

> ```yaml
> Difficulty: 🔴🔴⚪⚪⚪
> Importance: 🔵🔵⚪⚪⚪
> 
> You should spend up to 10-15 minutes on this exercise.
> ```

We'll start by implementing the `get_minibatch_indices` function, as described in [detail #6](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/#:~:text=Mini%2Dbatch%20Updates). This will give us a list of `num_minibatches = batch_size // minibatch_size` index arrays, each of length `minibatch_size`, which collectively form a permutation of the indices `[0, 1, ..., batch_size - 1]`, where `batch_size = num_minibatches * minibatch_size`. To help visualize how this works to create our minibatches, we've included a diagram:

<img src="https://raw.githubusercontent.com/info-arena/ARENA_img/main/misc/ppo-buffer-sampling-4.png" width="600">

The test code below should also make it clear what your function should be returning.

In [24]:
def get_minibatch_indices(rng: Generator, batch_size: int, minibatch_size: int) -> list[np.ndarray]:
    """
    Return a list of length `num_minibatches`, where each element is an array of `minibatch_size` and the union of all
    the arrays is the set of indices [0, 1, ..., batch_size - 1] where `batch_size = num_steps_per_rollout * num_envs`.
    """
    assert batch_size % minibatch_size == 0
    num_minibatches = batch_size // minibatch_size
    # Shuffle all indices once, then split the permutation into equally sized minibatches
    return list(rng.permutation(batch_size).reshape(num_minibatches, minibatch_size))

rng = np.random.default_rng(0)

batch_size = 12
minibatch_size = 6
# num_minibatches = batch_size // minibatch_size = 2

indices = get_minibatch_indices(rng, batch_size, minibatch_size)

assert isinstance(indices, list)
assert all(isinstance(x, np.ndarray) for x in indices)
assert np.array(indices).shape == (2, 6)
assert sorted(np.unique(indices)) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
print("All tests in `test_minibatch_indexes` passed!")

All tests in `test_minibatch_indexes` passed!


<details><summary>Solution</summary>

```python
def get_minibatch_indices(rng: Generator, batch_size: int, minibatch_size: int) -> list[np.ndarray]:
    """
    Return a list of length `num_minibatches`, where each element is an array of `minibatch_size` and the union of all
    the arrays is the set of indices [0, 1, ..., batch_size - 1] where `batch_size = num_steps_per_rollout * num_envs`.
    """
    assert batch_size % minibatch_size == 0
    num_minibatches = batch_size // minibatch_size
    indices = rng.permutation(batch_size).reshape(num_minibatches, minibatch_size)
    return list(indices)
```
</details>

### `ReplayMemory` class

Next, we've given you the `ReplayMemory` class. This follows a very similar structure to the DQN equivalent `ReplayBuffer` yesterday, with a bit of added complexity. We'll highlight the key differences below:

- There's no `[-self.buffer_size:]` slicing like there was in the DQN buffer yesterday. That's because rather than continually adding to our buffer and removing the oldest data, we'll iterate through a process of (fill entire memory, generate a bunch of minibatches from that memory and train on them, empty the memory, repeat).
- The `get_minibatches` method computes the advantages and returns. This isn't really in line with the SoC (separation of concerns) principle, but this is the easiest place to compute them because we can't do it after we sample the minibatches.
- A single learning phase involves creating `num_minibatches = batch_size // minibatch_size` minibatches and training on each of them, and then repeating this process `batches_per_learning_phase` times. So the total number of minibatches per learning phase is `batches_per_learning_phase * num_minibatches`.

<details>
<summary>Question - can you see why <code>advantages</code> can't be computed after we sample minibatches?</summary>

The samples are not in chronological order, they're shuffled. The formula for computing advantages requires the data to be in chronological order.

</details>

In [25]:
@dataclass
class ReplayMinibatch:
    """
    Samples from the replay memory, converted to PyTorch for use in neural network training.

    Data is equivalent to (s_t, a_t, logpi(a_t|s_t), A_t, A_t + V(s_t), d_{t+1})
    """

    obs: Float[Tensor, "minibatch_size *obs_shape"]
    actions: Int[Tensor, "minibatch_size *action_shape"]
    logprobs: Float[Tensor, "minibatch_size"]
    advantages: Float[Tensor, "minibatch_size"]
    returns: Float[Tensor, "minibatch_size"]
    terminated: Bool[Tensor, "minibatch_size"]


class ReplayMemory:
    """
    Contains buffer; has a method to sample from it to return a ReplayMinibatch object.
    """

    rng: Generator
    obs: Float[Arr, "buffer_size num_envs *obs_shape"]
    actions: Int[Arr, "buffer_size num_envs *action_shape"]
    logprobs: Float[Arr, "buffer_size num_envs"]
    values: Float[Arr, "buffer_size num_envs"]
    rewards: Float[Arr, "buffer_size num_envs"]
    terminated: Bool[Arr, "buffer_size num_envs"]

    def __init__(
        self,
        num_envs: int,
        obs_shape: tuple,
        action_shape: tuple,
        batch_size: int,
        minibatch_size: int,
        batches_per_learning_phase: int,
        seed: int = 42,
    ):
        self.num_envs = num_envs
        self.obs_shape = obs_shape
        self.action_shape = action_shape
        self.batch_size = batch_size
        self.minibatch_size = minibatch_size
        self.batches_per_learning_phase = batches_per_learning_phase
        self.rng = np.random.default_rng(seed)
        self.reset()

    def reset(self):
        """Resets all stored experiences, ready for new ones to be added to memory."""
        self.obs = np.empty((0, self.num_envs, *self.obs_shape), dtype=np.float32)
        self.actions = np.empty((0, self.num_envs, *self.action_shape), dtype=np.int32)
        self.logprobs = np.empty((0, self.num_envs), dtype=np.float32)
        self.values = np.empty((0, self.num_envs), dtype=np.float32)
        self.rewards = np.empty((0, self.num_envs), dtype=np.float32)
        self.terminated = np.empty((0, self.num_envs), dtype=bool)

    def add(
        self,
        obs: Float[Arr, "num_envs *obs_shape"],
        actions: Int[Arr, "num_envs *action_shape"],
        logprobs: Float[Arr, "num_envs"],
        values: Float[Arr, "num_envs"],
        rewards: Float[Arr, "num_envs"],
        terminated: Bool[Arr, "num_envs"],
    ) -> None:
        """Add a batch of transitions to the replay memory."""
        # Check shapes & datatypes
        for data, expected_shape in zip(
            [obs, actions, logprobs, values, rewards, terminated], [self.obs_shape, self.action_shape, (), (), (), ()]
        ):
            assert isinstance(data, np.ndarray)
            assert data.shape == (self.num_envs, *expected_shape)

        # Add data to buffer (not slicing off old elements)
        self.obs = np.concatenate((self.obs, obs[None, :]))
        self.actions = np.concatenate((self.actions, actions[None, :]))
        self.logprobs = np.concatenate((self.logprobs, logprobs[None, :]))
        self.values = np.concatenate((self.values, values[None, :]))
        self.rewards = np.concatenate((self.rewards, rewards[None, :]))
        self.terminated = np.concatenate((self.terminated, terminated[None, :]))

    def get_minibatches(
        self, next_value: Tensor, next_terminated: Tensor, gamma: float, gae_lambda: float
    ) -> list[ReplayMinibatch]:
        """
        Returns a list of minibatches. Each minibatch has size `minibatch_size`, and the union over all minibatches is
        `batches_per_learning_phase` copies of the entire replay memory.
        """
        # Convert everything to tensors on the correct device
        obs, actions, logprobs, values, rewards, terminated = (
            t.tensor(x, device=device)
            for x in [self.obs, self.actions, self.logprobs, self.values, self.rewards, self.terminated]
        )

        # Compute advantages & returns
        advantages = compute_advantages(next_value, next_terminated, rewards, values, terminated, gamma, gae_lambda)
        returns = advantages + values

        # Return a list of minibatches
        minibatches = []
        for _ in range(self.batches_per_learning_phase):
            for indices in get_minibatch_indices(self.rng, self.batch_size, self.minibatch_size):
                minibatches.append(
                    ReplayMinibatch(
                        *[
                            data.flatten(0, 1)[indices]
                            for data in [obs, actions, logprobs, advantages, returns, terminated]
                        ]
                    )
                )

        # Reset memory (since we only need to call this method once per learning phase)
        self.reset()

        return minibatches

Like before, here's some code to generate and plot observations.

The first plot shows the current observations $s_t$ (with dotted lines indicating a terminated episode $d_{t+1} = 1$). The solid lines indicate the transition between different environments in `envs` (because unlike yesterday, we're actually using more than one environment in our `SyncVectorEnv`). There are `batch_size = num_steps_per_rollout * num_envs = 128 * 2 = 256` observations in total, with `128` coming from each environment.

The second plot shows a single minibatch of sampled experiences from full memory. Each minibatch has size `minibatch_size = 128`, and `minibatches` contains in total `batches_per_learning_phase * (batch_size // minibatch_size) = 2 * 2 = 4` minibatches.

Note that we don't need to worry about terminal observations here, because we're not actually logging `next_obs` (unlike DQN, this won't be part of our loss function).

In [26]:
num_steps_per_rollout = 128
num_envs = 2
batch_size = num_steps_per_rollout * num_envs  # 256

minibatch_size = 128
num_minibatches = batch_size // minibatch_size  # 2

batches_per_learning_phase = 2

envs = gym.vector.SyncVectorEnv([make_env("CartPole-v1", i, i, "test") for i in range(num_envs)])
memory = ReplayMemory(num_envs, (4,), (), batch_size, minibatch_size, batches_per_learning_phase)

logprobs = values = np.zeros(envs.num_envs)  # dummy values, just so we can see demo of plot
obs, _ = envs.reset()

for i in range(num_steps_per_rollout):
    # Choose random action, and take a step in the environment
    actions = envs.action_space.sample()
    next_obs, rewards, terminated, truncated, infos = envs.step(actions)

    # Add experience to memory
    memory.add(obs, actions, logprobs, values, rewards, terminated)
    obs = next_obs

plot_cartpole_obs_and_dones(
    memory.obs,
    memory.terminated,
    title="Current obs s<sub>t</sub><br>Dotted lines indicate d<sub>t+1</sub> = 1, solid lines are environment separators",
)

next_value = next_done = t.zeros(envs.num_envs).to(device)  # dummy values, just so we can see demo of plot
minibatches = memory.get_minibatches(next_value, next_done, gamma=0.99, gae_lambda=0.95)

plot_cartpole_obs_and_dones(
    minibatches[0].obs.cpu(),
    minibatches[0].terminated.cpu(),
    title="Current obs (sampled)<br>this is what gets fed into our model for training",
)

## PPO Agent

As the final task in this section, you should fill in the agent's `play_step` method. This is conceptually similar to what you did during DQN, but with a few key differences:

- In DQN we selected actions based on our Q-network & an epsilon-greedy policy, but here your actions will be sampled directly from the distribution given by your actor network
- Here, you'll have to compute the extra data `logprobs` and `values`, which we didn't have to deal with in DQN

### Exercise - implement `PPOAgent`

> ```yaml
> Difficulty: 🔴🔴🔴🔴⚪
> Importance: 🔵🔵🔵🔵⚪
> 
> You should spend up to 20-40 minutes on this exercise.
> ```

A few tips:

- When sampling actions (and calculating logprobs), you might find `torch.distributions.categorical.Categorical` useful. If `logits` is a 2D tensor of shape `(N, k)` containing a batch of logit vectors and `dist = Categorical(logits=logits)`, then:
    - `actions = dist.sample()` will give you a vector of `N` sampled actions (which will be integers in the range `[0, k)`),
    - `logprobs = dist.log_prob(actions)` will give you a vector of the `N` logprobs corresponding to the sampled actions
- Make sure to use inference mode when using `obs` to compute `logits` and `values`, since all you're doing here is getting experiences for your memory - you aren't doing gradient descent based on these values.
- Check the shape of your arrays when adding them to memory (the `add` method has lots of `assert` statements here to help you), and also make sure that they are arrays not tensors by calling `.cpu().numpy()` on them.
- Remember to update `self.next_obs` and `self.next_terminated` at the end of the function!
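
To see what the two `Categorical` calls in the first tip compute, here is a NumPy sketch of the same quantities. It is an analogue for illustration, not the PyTorch implementation: the logits are made up, and in your solution you should use `Categorical` itself.

```python
import numpy as np

rng = np.random.default_rng(0)
logits = np.array([[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]])  # shape (N=2, k=3), made-up values

# Log-softmax turns unnormalized logits into log probabilities (subtract the max for numerical stability)
shifted = logits - logits.max(axis=-1, keepdims=True)
log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

# Analogue of `dist.sample()`: one action per row, drawn from that row's distribution
actions = np.array([rng.choice(logits.shape[1], p=np.exp(row)) for row in log_probs])

# Analogue of `dist.log_prob(actions)`: the log probability of each sampled action
logprobs = log_probs[np.arange(len(actions)), actions]
```

Both `actions` and `logprobs` have shape `(N,)`, which matches the per-environment shapes that the `add` method of the replay memory checks for.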

In [50]:
class PPOAgent:
    critic: nn.Sequential
    actor: nn.Sequential

    def __init__(self, envs: gym.vector.SyncVectorEnv, actor: nn.Module, critic: nn.Module, memory: ReplayMemory):
        super().__init__()
        self.envs = envs
        self.actor = actor
        self.critic = critic
        self.memory = memory # what's in memory? obs, actions, log_probs, values, rewards, terminated

        self.step = 0  # Tracking number of steps taken (across all environments)
        self.next_obs = t.tensor(envs.reset()[0], device=device, dtype=t.float)  # need starting obs (in tensor form)
        self.next_terminated = t.zeros(envs.num_envs, device=device, dtype=t.bool)  # need starting termination=False

    def play_step(self) -> list[dict]:
        """
        Carries out a single interaction step between the agent and the environment, and adds results to the replay memory.

        Returns the list of info dicts returned from `self.envs.step`.
        """
        # Get newest observations (i.e. where we're starting from)
        obs = self.next_obs # tensor
        terminated = self.next_terminated
        
        # Compute logits from the newest observation, and sample actions from the resulting distribution
        with t.inference_mode():
            logits = self.actor(obs)
        dist = Categorical(logits=logits)
        actions = dist.sample()  # shape (num_envs,)

        # Step the environments with the sampled actions. This returns the next observations, rewards,
        # termination flags, truncation flags and infos
        next_obs, rewards, next_terminated, next_truncated, infos = self.envs.step(actions.cpu().numpy())

        # calculate logprobs and values and add to replay
        logprobs = dist.log_prob(actions)
        with t.inference_mode():
            values = self.critic(obs).flatten().cpu().numpy()
        self.memory.add(obs.cpu().numpy(), actions.cpu().numpy(), logprobs.cpu().numpy(), values, rewards, terminated.cpu().numpy())

        # update next obs and terminated 
        self.next_obs = t.from_numpy(next_obs).to(device, dtype=t.float)
        self.next_terminated = t.from_numpy(next_terminated).to(device, dtype=t.float)
        self.step += self.envs.num_envs
        return infos

    def get_minibatches(self, gamma: float, gae_lambda: float) -> list[ReplayMinibatch]:
        """
        Gets minibatches from the replay memory, and resets the memory
        """
        with t.inference_mode():
            next_value = self.critic(self.next_obs).flatten()
        minibatches = self.memory.get_minibatches(next_value, self.next_terminated, gamma, gae_lambda)
        self.memory.reset()
        return minibatches


tests.test_ppo_agent(PPOAgent)

All tests in `test_agent` passed!


<details><summary>Solution</summary>

```python
class PPOAgent:
    critic: nn.Sequential
    actor: nn.Sequential

    def __init__(self, envs: gym.vector.SyncVectorEnv, actor: nn.Module, critic: nn.Module, memory: ReplayMemory):
        super().__init__()
        self.envs = envs
        self.actor = actor
        self.critic = critic
        self.memory = memory

        self.step = 0  # Tracking number of steps taken (across all environments)
        self.next_obs = t.tensor(envs.reset()[0], device=device, dtype=t.float)  # need starting obs (in tensor form)
        self.next_terminated = t.zeros(envs.num_envs, device=device, dtype=t.bool)  # need starting termination=False

    def play_step(self) -> list[dict]:
        """
        Carries out a single interaction step between the agent and the environment, and adds results to the replay memory.

        Returns the list of info dicts returned from `self.envs.step`.
        """
        # Get newest observations (i.e. where we're starting from)
        obs = self.next_obs
        terminated = self.next_terminated

        # Compute logits based on newest observation, and use it to get an action distribution we sample from
        with t.inference_mode():
            logits = self.actor(obs)
        dist = Categorical(logits=logits)
        actions = dist.sample()

        # Step environment based on the sampled action
        next_obs, rewards, next_terminated, next_truncated, infos = self.envs.step(actions.cpu().numpy())

        # Calculate logprobs and values, and add this all to replay memory
        logprobs = dist.log_prob(actions).cpu().numpy()
        with t.inference_mode():
            values = self.critic(obs).flatten().cpu().numpy()
        self.memory.add(obs.cpu().numpy(), actions.cpu().numpy(), logprobs, values, rewards, terminated.cpu().numpy())

        # Set next observation & termination state
        self.next_obs = t.from_numpy(next_obs).to(device, dtype=t.float)
        self.next_terminated = t.from_numpy(next_terminated).to(device, dtype=t.float)

        self.step += self.envs.num_envs
        return infos

    def get_minibatches(self, gamma: float, gae_lambda: float) -> list[ReplayMinibatch]:
        """
        Gets minibatches from the replay memory, and resets the memory
        """
        with t.inference_mode():
            next_value = self.critic(self.next_obs).flatten()
        minibatches = self.memory.get_minibatches(next_value, self.next_terminated, gamma, gae_lambda)
        self.memory.reset()
        return minibatches
```
</details>

# 2️⃣ Learning Phase

> ##### Learning Objectives
>
> - Implement the total objective function (sum of three separate terms)
> - Understand the importance of each of these terms for the overall algorithm
> - Write a function to return an optimizer and learning rate scheduler for your model

In the last section, we wrote a lot of setup code (including handling most of how our rollout phase will work). Next, we'll turn to the learning phase.

In the next exercises, you'll write code to compute your total objective function. This is given by equation $(9)$ in the paper, and is the sum of three terms - we'll implement each one individually.

Note - the convention we've used in these exercises for signs is that **your function outputs should be the expressions in equation $(9)$**, in other words you will compute $L_t^{CLIP}(\theta)$, $c_1 L_t^{VF}(\theta)$ and $c_2 S[\pi_\theta](s_t)$. We will then perform **gradient ascent** by passing `maximize=True` into our optimizers. An equally valid solution would be to just return the negative of the objective function.
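
The equivalence of the two sign conventions can be checked on a toy problem. The sketch below uses a made-up one-dimensional objective and plain Python updates rather than a PyTorch optimizer; it only illustrates that ascending an objective and descending its negative produce the same parameter trajectory.

```python
def objective_grad(theta):
    """Gradient of the toy objective J(theta) = -(theta - 3)^2, which is maximised at theta = 3."""
    return -2.0 * (theta - 3.0)


lr = 0.1
theta_ascent = theta_descent = 0.0
for _ in range(100):
    # Gradient ascent on the objective J
    theta_ascent += lr * objective_grad(theta_ascent)
    # Gradient descent on the loss -J, whose gradient is the negative of the gradient of J
    theta_descent -= lr * (-objective_grad(theta_descent))

# Both conventions take identical steps and end up near the maximiser theta = 3
assert theta_ascent == theta_descent
```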

## Objective function

### Clipped Surrogate Objective ([detail #8](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/#:~:text=Clipped%20surrogate%20objective,-(ppo2/model)))

For each minibatch, calculate $L^{CLIP}$ from equation $(7)$ in [the paper](https://arxiv.org/abs/1707.06347). This will allow us to improve the parameters of our actor.
$$
L^{CLIP}(\theta) = \frac{1}{|B|} \sum_t \left[\min \left(r_{t}(\theta) \hat{A}_t, \operatorname{clip}\left(r_{t}(\theta), 1-\epsilon, 1+\epsilon\right) \hat{A}_t\right)\right]
$$
Where we define $r_t(\theta) = \dfrac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{\text{target}}}(a_t|s_t)}$ as the probability ratio between the current policy probability and the old policy probability (don't confuse this with reward $r_t$!). 

If you want to recap the derivation of this expression you can use the dropdown below or go back to section 0️⃣, but the key intuition is as follows:

- Ignoring clipping, this objective function will lead to a positive gradient for $\pi_\theta(a_t \mid s_t)$ when the action $a_t$ had positive advantage - in other words, **take more good actions and fewer bad actions!**
- The clipping formula looks complicated, but actually all it's saying is "if the probability ratio goes outside the bounds $[1-\epsilon, 1+\epsilon]$ then we no longer apply gradients to it". This is good because it means if the policy changes significantly, we don't overupdate on experiences that weren't generated by this new changed policy. This is where the "proximal" part of PPO comes from.

<details>
<summary>Click here to see the full derivation</summary>

To fully explain this function, we start with the non-proximal version which doesn't involve clipping:
$$
L(\theta) = \frac{1}{|B|} \sum_t \left(\frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{\text{target}}}(a_t|s_t)} \hat{A}_t \right)
$$

If we replace the probability ratio here with the log probability $\log \pi_\theta(a_t | s_t)$ then we get something that looks a lot more like the loss function we saw in section 0️⃣, i.e. the one whose gradient $\nabla_\theta L(\theta)$ we proved to be equal to the gradient of the sum of expected returns $\nabla_\theta J(\pi_\theta)$ (see section 0️⃣ for proof). The reason we use the probability ratio here rather than the log probability is for stability - as long as $\pi_\theta$ and $\pi_{\theta_\text{target}}$ are sufficiently close, this ratio will be close to 1, and so we can use the approximation $\log x \approx x - 1$ for $x$ close to 1. With this approximation, we get:

$$
\log(\pi_\theta(a_t \mid s_t)) - \log(\pi_{\theta_\text{target}}(a_t \mid s_t)) = \log\left(\frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_\text{target}}(a_t \mid s_t)}\right) \approx \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_\text{target}}(a_t \mid s_t)} - 1
$$

Rearranging, we get $\dfrac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_\text{target}}(a_t \mid s_t)} \approx \log(\pi_\theta(a_t \mid s_t)) + \text{const}$, where $\text{const}$ is independent of $\theta$.

This means the substitution we made above is valid.

Why can we assume $x$ is close to 1? That brings us to the second part of the objective function, the clipping. The $\text{min}$ and $\text{clip}$ in the expression might look a bit messy and confusing at first, but it's a lot simpler than it looks - in practice all this means is that we clip the probability ratio when it's larger than $1+\epsilon$ for positive advantages, or less than $1-\epsilon$ for negative advantages. For example, if $\hat{A}_t > 0$ then the lower bound $1-\epsilon$ never affects the result: whenever $r_t(\theta) < 1-\epsilon$, the unclipped term $r_t(\theta) \hat{A}_t$ is smaller than $(1-\epsilon) \hat{A}_t$, so the $\text{min}$ selects it anyway. The expression therefore reduces to $\text{min}\left(r_t(\theta) \hat{A}_t, (1+\epsilon) \hat{A}_t\right) = \text{min}(r_t(\theta), 1+\epsilon) \hat{A}_t$. The illustration below might help.

</details>

You might find the illustration below more helpful for explaining how the clipping works:

<img src="https://raw.githubusercontent.com/info-arena/ARENA_img/main/misc/advantages.png" width="800">
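
The same behaviour can be checked numerically. The sketch below evaluates the per-sample term inside the sum of $L^{CLIP}$ for a few made-up probability ratios, using NumPy rather than PyTorch; the helper name `clipped_term` and the choice $\epsilon = 0.2$ are assumptions for illustration.

```python
import numpy as np


def clipped_term(ratio, advantage, clip_coef=0.2):
    """Per-sample term min(r * A, clip(r, 1 - eps, 1 + eps) * A)."""
    unclipped = ratio * advantage
    clipped = np.clip(ratio, 1 - clip_coef, 1 + clip_coef) * advantage
    return np.minimum(unclipped, clipped)


ratios = np.array([0.5, 1.0, 1.5])  # made-up probability ratios

# Positive advantage: the term stops growing once the ratio exceeds 1 + eps
positive = clipped_term(ratios, advantage=1.0)   # approximately [0.5, 1.0, 1.2]

# Negative advantage: the term stops changing once the ratio drops below 1 - eps
negative = clipped_term(ratios, advantage=-1.0)  # approximately [-0.8, -1.0, -1.5]
```

In both cases the flat region is where the policy has already moved far enough in the direction the advantage favours, so further movement in that direction earns no extra objective.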

### Exercise - write `calc_clipped_surrogate_objective`

> ```yaml
> Difficulty: 🔴🔴🔴⚪⚪
> Importance: 🔵🔵🔵🔵⚪
> 
> You should spend up to 10-25 minutes on this exercise.
> ```

Implement the function below. The current probabilities $\pi_\theta(a_t \mid s_t)$ come from the actions `mb_action` evaluated using the newer `probs`, and the old probabilities $\pi_{\theta_\text{target}}(a_t \mid s_t)$ come from the stored `mb_logprobs`.

A few things to note:

- You should pay attention to the normalization instructions in [detail #7](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/#:~:text=Normalization%20of%20Advantages) when implementing this loss function. They add a value of `eps = 1e-8` to the denominator to avoid division by zero, and you should do the same.
- You can use the `probs.log_prob` method to get the log probabilities that correspond to the actions in `mb_action`.
    - If you're wondering why we're using a `Categorical` type rather than just using `log_prob` directly, it's because we'll be using them to sample actions later on in our `train_ppo` function. Also, categoricals have a useful method for returning the entropy of a distribution (which will be useful for the entropy term in the loss function).
- Our `mb_action` has shape `(minibatch_size, *action_shape)`, but in most of the environments we're dealing with (CartPole, and later the Breakout Atari env) the action shape is an empty tuple, which is why we have the assert statement at the start of this function.
- The clip formula can be a bit finicky (i.e. when you take signs and max/mins), so we recommend breaking the computation onto a few separate lines rather than doing it all in one go!
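
The advantage normalization from the first note can be sketched in NumPy as follows. The helper name `normalize_advantages` and the numbers are made up; `ddof=1` is used because `torch.Tensor.std` computes the sample standard deviation by default, whereas NumPy's default is the population standard deviation.

```python
import numpy as np


def normalize_advantages(advantages, eps=1e-8):
    """Shift to zero mean and scale to unit standard deviation; eps guards against a zero std."""
    return (advantages - advantages.mean()) / (advantages.std(ddof=1) + eps)


advantages = np.array([1.0, 2.0, 3.0, 6.0])  # made-up minibatch of advantages
normalized = normalize_advantages(advantages)

# A constant minibatch has zero std; eps keeps the result finite (all zeros)
constant = normalize_advantages(np.array([0.5, 0.5, 0.5]))
```

Note that the normalization is applied per minibatch, so the same experience can receive slightly different normalized advantages in different minibatches.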

In [65]:
def calc_clipped_surrogate_objective(
    probs: Categorical,
    mb_action: Int[Tensor, "minibatch_size"],
    mb_advantages: Float[Tensor, "minibatch_size"],
    mb_logprobs: Float[Tensor, "minibatch_size"],
    clip_coef: float,
    eps: float = 1e-8,
) -> Float[Tensor, ""]:
    """Return the clipped surrogate objective, suitable for maximisation with gradient ascent.

    probs:
        a distribution containing the actor's unnormalized logits of shape (minibatch_size, num_actions)
    mb_action:
        which actions were taken in the sampled minibatch
    mb_advantages:
        advantages calculated from the sampled minibatch
    mb_logprobs:
        logprobs of the actions taken in the sampled minibatch (according to the old policy)
    clip_coef:
        amount of clipping, denoted by epsilon in Eq 7.
    eps:
        used to add to std dev of mb_advantages when normalizing (to avoid dividing by zero)
    """
    assert mb_action.shape == mb_advantages.shape == mb_logprobs.shape
    # Ratio of new to old action probabilities, via the difference of log-probs
    logprobs_diff = probs.log_prob(mb_action) - mb_logprobs
    prob_ratio = t.exp(logprobs_diff)

    # Normalize advantages at the minibatch level (detail #7)
    mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + eps)

    unclipped = mb_advantages * prob_ratio
    clipped = t.clip(prob_ratio, 1 - clip_coef, 1 + clip_coef) * mb_advantages
    return t.minimum(unclipped, clipped).mean()
    

tests.test_calc_clipped_surrogate_objective(calc_clipped_surrogate_objective)

All tests in `test_calc_clipped_surrogate_objective` passed.


<details><summary>Solution</summary>

```python
def calc_clipped_surrogate_objective(
    probs: Categorical,
    mb_action: Int[Tensor, "minibatch_size"],
    mb_advantages: Float[Tensor, "minibatch_size"],
    mb_logprobs: Float[Tensor, "minibatch_size"],
    clip_coef: float,
    eps: float = 1e-8,
) -> Float[Tensor, ""]:
    """Return the clipped surrogate objective, suitable for maximisation with gradient ascent.

    probs:
        a distribution containing the actor's unnormalized logits of shape (minibatch_size, num_actions)
    mb_action:
        which actions were taken in the sampled minibatch
    mb_advantages:
        advantages calculated from the sampled minibatch
    mb_logprobs:
        logprobs of the actions taken in the sampled minibatch (according to the old policy)
    clip_coef:
        amount of clipping, denoted by epsilon in Eq 7.
    eps:
        used to add to std dev of mb_advantages when normalizing (to avoid dividing by zero)
    """
    assert mb_action.shape == mb_advantages.shape == mb_logprobs.shape
    logits_diff = probs.log_prob(mb_action) - mb_logprobs

    prob_ratio = t.exp(logits_diff)

    mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + eps)

    non_clipped = prob_ratio * mb_advantages
    clipped = t.clip(prob_ratio, 1 - clip_coef, 1 + clip_coef) * mb_advantages

    return t.minimum(non_clipped, clipped).mean()


tests.test_calc_clipped_surrogate_objective(calc_clipped_surrogate_objective)
```
</details>
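
To see why the sign handling in the clip formula matters, here is a minimal pure-Python sketch of the per-sample term $\min(r A, \text{clip}(r, 1-\epsilon, 1+\epsilon) A)$. The helper name `clipped_term` and the numbers are illustrative assumptions, not part of the course code.

```python
def clipped_term(ratio: float, advantage: float, clip_coef: float) -> float:
    """Per-sample clipped surrogate term: min(r * A, clip(r, 1 - eps, 1 + eps) * A)."""
    clipped_ratio = max(1 - clip_coef, min(ratio, 1 + clip_coef))
    return min(ratio * advantage, clipped_ratio * advantage)


# Positive advantage: the gain from raising the ratio is capped at (1 + eps) * A
print(clipped_term(ratio=1.5, advantage=2.0, clip_coef=0.2))   # 2.4
# Negative advantage: the gain from lowering the ratio is capped at (1 - eps) * A
print(clipped_term(ratio=0.5, advantage=-2.0, clip_coef=0.2))  # -1.6
# Moving in the "wrong" direction is not clipped, so the penalty is felt in full
print(clipped_term(ratio=1.5, advantage=-2.0, clip_coef=0.2))  # -3.0
```

In the first two cases the result no longer depends on the ratio, so those samples contribute no gradient. In the third case the unclipped term is selected, so the policy is still pushed back towards the old one.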

### Value Function Loss ([detail #9](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/#:~:text=Value%20Function%20Loss%20Clipping))

The value function loss lets us improve the parameters of our critic. Today we're going to implement the simple form: this is just the mean squared difference between the following two terms:

* The **critic's prediction** - this is $V_\theta(s_t)$ in the paper, and `values` in our code (i.e. values computed from the updated critic network)
* The **observed returns** - this is $V_t^\text{targ}$ in the paper, and `returns = memory.advantages + memory.values` in our code (i.e. values generated during rollout)

Note that the observed returns are equivalent to the observed next-step Q values, meaning the difference between these two terms is the TD error (and our loss is its mean square). It's analogous to the TD error we used during SARSA: the purpose is to bring $V_\theta(s_t)$ closer in line with the actual value, estimated with the benefit of seeing more future actions. Remember that because we're using GAE, the observed returns $V_t^\text{targ}$ don't just take into account the next action, but many future actions in the trajectory.

*Note - the PPO paper did a more complicated thing with clipping, but we're going to deviate from the paper and NOT clip, since [detail #9](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/#:~:text=Value%20Function%20Loss%20Clipping) gives evidence that it isn't beneficial.*

### Exercise - write `calc_value_function_loss`

> ```yaml
> Difficulty: 🔴🔴⚪⚪⚪
> Importance: 🔵🔵🔵⚪⚪
> 
> You should spend up to 5-10 minutes on this exercise.
> ```

Implement `calc_value_function_loss` which returns the term denoted $c_1 L_t^{VF}$ in equation $(9)$.

In [None]:
def calc_value_function_loss(
    values: Float[Tensor, "minibatch_size"], mb_returns: Float[Tensor, "minibatch_size"], vf_coef: float
) -> Float[Tensor, ""]:
    """Compute the value function portion of the loss function.

    values:
        the value function predictions for the sampled minibatch (using the updated critic network)
    mb_returns:
        the target for our updated critic network (computed as `advantages + values` from the old network)
    vf_coef:
        the coefficient for the value loss, which weights its contribution to the overall loss. Denoted by c_1 in the paper.
    """
    assert values.shape == mb_returns.shape
    # `.mean()` averages over every element, so the loss is the average squared error per sample
    return vf_coef * ((values - mb_returns) ** 2).mean()


tests.test_calc_value_function_loss(calc_value_function_loss)

All tests in `test_calc_value_function_loss` passed!


<details><summary>Solution</summary>

```python
def calc_value_function_loss(
    values: Float[Tensor, "minibatch_size"], mb_returns: Float[Tensor, "minibatch_size"], vf_coef: float
) -> Float[Tensor, ""]:
    """Compute the value function portion of the loss function.

    values:
        the value function predictions for the sampled minibatch (using the updated critic network)
    mb_returns:
        the target for our updated critic network (computed as `advantages + values` from the old network)
    vf_coef:
        the coefficient for the value loss, which weights its contribution to the overall loss. Denoted by c_1 in the paper.
    """
    assert values.shape == mb_returns.shape

    return vf_coef * (values - mb_returns).pow(2).mean()
```
</details>
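
As a concrete check of the relationship `returns = advantages + values`, the pure-Python sketch below works through three timesteps by hand. The function name `value_function_loss` and all the numbers are illustrative assumptions. It mirrors the formula above on plain lists rather than tensors.

```python
def value_function_loss(new_values: list[float], returns: list[float], vf_coef: float) -> float:
    """vf_coef times the mean squared difference between critic predictions and targets."""
    squared_errors = [(v - r) ** 2 for v, r in zip(new_values, returns)]
    return vf_coef * sum(squared_errors) / len(squared_errors)


# Quantities stored during rollout (old critic) for three timesteps
old_values = [1.0, 2.0, 3.0]
advantages = [0.5, -1.0, 0.0]
returns = [a + v for a, v in zip(advantages, old_values)]  # targets: [1.5, 1.0, 3.0]

# Predictions from the updated critic on the same observations
new_values = [1.0, 1.0, 2.0]
loss = value_function_loss(new_values, returns, vf_coef=0.5)
print(loss)  # squared errors are 0.25, 0.0 and 1.0, so the loss is 0.5 * 1.25 / 3
```

The targets are fixed during the learning phase. Only `new_values` changes as the critic is updated.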

### Entropy Bonus ([detail #10](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/#:~:text=Overall%20Loss%20and%20Entropy%20Bonus))

The entropy bonus term is intended to incentivize exploration by increasing the entropy of the action distribution. For a discrete probability distribution $p$, the entropy $H$ is defined as
$$
H(p) = \sum_x p(x) \ln \frac{1}{p(x)}
$$
If $p(x) = 0$, then we define $0 \ln \frac{1}{0} := 0$ (by taking the limit as $p(x) \to 0$).
You should understand what entropy of a discrete distribution means, but you don't have to implement it yourself: `probs.entropy` computes it using the above formula but in a numerically stable way, and in a way that handles the case where $p(x) = 0$.

Question: in CartPole, what are the minimum and maximum values that entropy can take? What behaviors correspond to each of these cases?

<details>
<summary>Answer</summary>

The minimum entropy is zero, under the policy "always move left" or "always move right".

The maximum entropy is $\ln(2) \approx 0.693$ under the uniform random policy over the 2 actions.
</details>
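
The two extremes can be checked numerically with a few lines of plain Python. This is only a sketch of the formula above (the function name `entropy` is an illustrative choice), not the numerically stable version used by `Categorical`.

```python
import math


def entropy(probs: list[float]) -> float:
    """Shannon entropy in nats, using the convention 0 * ln(1/0) = 0."""
    return sum(p * math.log(1 / p) for p in probs if p > 0)


print(entropy([1.0, 0.0]))  # 0.0: deterministic "always move left" policy
print(entropy([0.5, 0.5]))  # ln(2), about 0.693: uniform random policy
print(entropy([0.9, 0.1]))  # strictly between the two extremes
```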

Separately from its role in the loss function, the entropy of our action distribution is a useful diagnostic to have: if the entropy of the agent's actions is near the maximum, it's playing nearly randomly, which means it isn't learning anything (assuming the optimal policy isn't random). If it is near the minimum, especially early in training, then the agent might not be exploring enough.

### Exercise - write `calc_entropy_bonus`

> ```yaml
> Difficulty: 🔴⚪⚪⚪⚪
> Importance: 🔵🔵🔵⚪⚪
> 
> You should spend up to ~10 minutes on this exercise.
> ```

In [76]:
def calc_entropy_bonus(dist: Categorical, ent_coef: float):
    """Return the entropy bonus term, suitable for gradient ascent.

    dist:
        the probability distribution for the current policy
    ent_coef:
        the coefficient for the entropy loss, which weights its contribution to the overall objective function. Denoted by c_2 in the paper.
    """
    # raise NotImplementedError()
    return ent_coef * dist.entropy().mean()


tests.test_calc_entropy_bonus(calc_entropy_bonus)

All tests in `test_calc_entropy_bonus` passed!


<details><summary>Solution</summary>

```python
def calc_entropy_bonus(dist: Categorical, ent_coef: float):
    """Return the entropy bonus term, suitable for gradient ascent.

    dist:
        the probability distribution for the current policy
    ent_coef:
        the coefficient for the entropy loss, which weights its contribution to the overall objective function. Denoted by c_2 in the paper.
    """
    return ent_coef * dist.entropy().mean()
```
</details>

## Adam Optimizer & Scheduler (details [#3](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/#:~:text=The%20Adam%20Optimizer%E2%80%99s%20Epsilon%20Parameter) & [#4](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/#:~:text=Adam%20Learning%20Rate%20Annealing))

Even though Adam is already an adaptive learning rate optimizer, empirically it's still beneficial to decay the learning rate.

### Exercise - implement `PPOScheduler`

> ```yaml
> Difficulty: 🔴🔴🔴⚪⚪
> Importance: 🔵🔵⚪⚪⚪
> 
> You should spend up to 10-15 minutes on this exercise.
> ```

Implement a linear decay from `initial_lr` to `end_lr` over `total_training_steps` steps. Also, make sure you read details [#3](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/#:~:text=The%20Adam%20Optimizer%E2%80%99s%20Epsilon%20Parameter) and [#4](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/#:~:text=Adam%20Learning%20Rate%20Annealing) so you don't miss any of the Adam implementation details. Note that training terminates after `num_updates`, so you don't need to worry about what the learning rate will be after this point.

Recall from our look at optimizers in the first week: we edit hyperparameters like learning rate of an optimizer by iterating through `optimizer.param_groups` and setting the `param_group["lr"]` attribute.

We've implemented the `make_optimizer` function for you. Note that we've passed in `itertools.chain` to the optimizer - this creates a single iterable out of multiple iterables, and is necessary because the optimizer expects a single iterable. Another option would just be to convert the parameters to lists and concatenate them.

In [78]:
class PPOScheduler:
    def __init__(self, optimizer: Optimizer, initial_lr: float, end_lr: float, total_phases: int):
        self.optimizer = optimizer
        self.initial_lr = initial_lr
        self.end_lr = end_lr
        self.total_phases = total_phases
        self.n_step_calls = 0

    def step(self):
        """Implement linear learning rate decay so that after `total_phases` calls to step, the learning rate is end_lr.

        Do this by directly editing the learning rates inside each param group (i.e. `param_group["lr"] = ...`), for each param
        group in `self.optimizer.param_groups`.
        """
        # raise NotImplementedError()
        self.n_step_calls += 1
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.initial_lr + (self.end_lr - self.initial_lr) * self.n_step_calls / self.total_phases


def make_optimizer(
    actor: nn.Module, critic: nn.Module, total_phases: int, initial_lr: float, end_lr: float = 0.0
) -> tuple[optim.AdamW, PPOScheduler]:
    """
    Return an appropriately configured AdamW optimizer (set up for gradient ascent via `maximize=True`) with its
    attached scheduler.
    """
    optimizer = optim.AdamW(
        itertools.chain(actor.parameters(), critic.parameters()), lr=initial_lr, eps=1e-5, maximize=True
    )
    scheduler = PPOScheduler(optimizer, initial_lr, end_lr, total_phases)
    return optimizer, scheduler


tests.test_ppo_scheduler(PPOScheduler)

All tests in `test_ppo_scheduler` passed!


<details><summary>Solution</summary>

```python
class PPOScheduler:
    def __init__(self, optimizer: Optimizer, initial_lr: float, end_lr: float, total_phases: int):
        self.optimizer = optimizer
        self.initial_lr = initial_lr
        self.end_lr = end_lr
        self.total_phases = total_phases
        self.n_step_calls = 0

    def step(self):
        """Implement linear learning rate decay so that after `total_phases` calls to step, the learning rate is end_lr.

        Do this by directly editing the learning rates inside each param group (i.e. `param_group["lr"] = ...`), for each param
        group in `self.optimizer.param_groups`.
        """
        self.n_step_calls += 1
        frac = self.n_step_calls / self.total_phases
        assert frac <= 1
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = self.initial_lr + frac * (self.end_lr - self.initial_lr)
```
</details>
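
To make the linear decay concrete, here is a small pure-Python sketch that evaluates the same interpolation formula for each call to `step`. The helper name `linear_lr` and the toy values (`initial_lr=1.0`, `end_lr=0.0`, 4 phases) are assumptions chosen so the arithmetic is easy to follow.

```python
def linear_lr(initial_lr: float, end_lr: float, n_step_calls: int, total_phases: int) -> float:
    """Learning rate after `n_step_calls` calls to `step`, interpolating linearly."""
    frac = n_step_calls / total_phases
    return initial_lr + frac * (end_lr - initial_lr)


schedule = [linear_lr(1.0, 0.0, n, total_phases=4) for n in range(5)]
print(schedule)  # [1.0, 0.75, 0.5, 0.25, 0.0]
```

The first entry is the learning rate before any call to `step`, and the last is the value reached after `total_phases` calls.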

# 3️⃣ Training Loop

> ##### Learning Objectives
>
> - Build a full training loop for the PPO algorithm
> - Train our agent, and visualise its performance with Weights & Biases media logger
> - Use reward shaping to improve your agent's training (and make it do tricks!)

## Writing your training loop

Finally, we can package this all together into our full training loop. The `train` function has been written for you: it just performs an alternating sequence of rollout & learning phases, a total of `args.total_phases` times each. You can see in the `__post_init__` method of our dataclass how this value was calculated by dividing the total agent steps by the batch size (which is the number of agent steps required per rollout phase).

Your job will be to fill in the logic for the rollout & learning phases. This will involve using many of the functions you've written in the last 2 sections.
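
As a rough illustration of how the number of phases falls out of the other settings, the sketch below repeats that calculation with concrete numbers. The specific values and the use of integer division are assumptions for illustration. Check `PPOArgs.__post_init__` for the actual definitions.

```python
# Illustrative values only (assumptions, not necessarily the PPOArgs defaults)
total_timesteps = 7500
num_envs = 4
num_steps_per_rollout = 128

# One rollout phase collects one batch of experiences across all environments
batch_size = num_steps_per_rollout * num_envs
# Assumption: integer division, so any leftover partial batch is dropped
total_phases = total_timesteps // batch_size

print(batch_size, total_phases)  # 512 14
```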

### Exercise - complete the `PPOTrainer` class

> ```yaml
> Difficulty: 🔴🔴🔴🔴🔴
> Importance: 🔵🔵🔵🔵🔵
> 
> You should spend up to 30-60 minutes on this exercise (including logging).
> It will be the hardest exercise today.
> ```

You should fill in the following methods. Ignoring logging, they should do the following:

- `rollout_phase`
    - Step the agent through the environment for `num_steps_per_rollout` total steps, which collects `num_steps_per_rollout * num_envs` experiences into the replay memory
    - This will be near identical to yesterday's `add_to_replay_buffer` method
- `learning_phase`
    - Sample from the replay memory using `agent.get_minibatches` (which returns a list of minibatches); this automatically resets the memory
    - Iterate over these minibatches, and for each minibatch you should backprop wrt the objective function computed from the `compute_ppo_objective` method
    - Note that after each `backward()` call, you should also **clip the gradients** in accordance with [detail #11](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/#:~:text=Global%20Gradient%20Clipping%20)
        - You can use `nn.utils.clip_grad_norm_(parameters, max_grad_norm)` for this - see [documentation page](https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html). The `args` dataclass contains the max norm for clipping gradients
    - Also remember to step the optimizer _and_ scheduler at the end of the method
        - The optimizer should be stepped once per minibatch, but the scheduler should just be stepped **once per learning phase** (in classic ML, we generally step schedulers once per epoch)
- `compute_ppo_objective`
    - Handles actual computation of the PPO objective function
    - Note that you'll need to compute `logits` and `values` from the minibatch observation `minibatch.obs`, but unlike in our previous functions **this shouldn't be done in inference mode**, since these are actually the values that propagate gradients!
    - Also remember to get the sign correct - our optimizer was set up for **gradient ascent**, so we should return `total_objective_function = clipped_surrogate_objective - value_loss + entropy_bonus` from this method
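
To build intuition for the global gradient clipping step, here is a pure-Python sketch of the idea: all gradients are treated as one long vector, and they are rescaled together if that vector's L2 norm exceeds the maximum. The function `clip_by_global_norm` is a hypothetical helper written for this illustration. It captures the idea behind `nn.utils.clip_grad_norm_` rather than reproducing PyTorch's exact implementation.

```python
import math


def clip_by_global_norm(grads: list[list[float]], max_norm: float) -> tuple[list[list[float]], float]:
    """Rescale all gradients jointly so that their combined L2 norm is at most `max_norm`."""
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    scale = min(1.0, max_norm / total_norm) if total_norm > 0 else 1.0
    clipped = [[g * scale for g in grad] for grad in grads]
    return clipped, total_norm


# Two "parameter tensors" (e.g. one from the actor, one from the critic)
grads = [[3.0, 0.0], [0.0, 4.0]]
clipped, total_norm = clip_by_global_norm(grads, max_norm=0.5)
print(total_norm)  # 5.0
print(clipped)     # every entry scaled by 0.5 / 5.0 = 0.1
```

Because a single scale factor is applied to every gradient, the direction of the update is preserved and only its size changes. This is also why the actor and critic parameters are passed to the clipping function together.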

Once you get this working, you should also add logging:

- Log the data for any terminated episodes in `rollout_phase`
    - This should be the same as yesterday's exercise, in fact you can use the same `get_episode_data_from_infos` helper function (we've imported it for you at the top of this file)
- Log useful data related to your different objective function components in `compute_ppo_objective`
    - Some recommendations for what to log can be found in [detail #12](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/#:~:text=Debug%20variables)
    
We recommend not focusing too much on wandb & logging initially, just like yesterday. Once again you have the probe environments to test your code, and even after that point you'll get better feedback loops by turning off wandb until you're more confident in your solution. The most important thing to log is the episode length & reward in `rollout_phase`, and if you have this appearing on your progress bar then you'll be able to get a good sense of how your agent is doing. Even without this and without wandb, videos of your runs will automatically be saved to the folder `part3_ppo/videos/run_name`, with `run_name` being the name set at initialization for your `PPOTrainer` class.

If you get stuck at any point during this implementation, you can look at the solutions or send a message in the Slack channel for help.

In [None]:
class PPOTrainer:
    def __init__(self, args: PPOArgs):
        set_global_seeds(args.seed)
        self.args = args
        self.run_name = f"{args.env_id}__{args.wandb_project_name}__seed{args.seed}__{time.strftime('%Y%m%d-%H%M%S')}"
        self.envs = gym.vector.SyncVectorEnv(
            [make_env(idx=idx, run_name=self.run_name, **args.__dict__) for idx in range(args.num_envs)]
        )

        # Define some basic variables from our environment
        self.num_envs = self.envs.num_envs
        self.action_shape = self.envs.single_action_space.shape
        self.obs_shape = self.envs.single_observation_space.shape

        # Create our replay memory
        self.memory = ReplayMemory(
            self.num_envs,
            self.obs_shape,
            self.action_shape,
            args.batch_size,
            args.minibatch_size,
            args.batches_per_learning_phase,
            args.seed,
        )

        # Create our networks & optimizer
        self.actor, self.critic = get_actor_and_critic(self.envs, mode=args.mode)
        self.optimizer, self.scheduler = make_optimizer(self.actor, self.critic, args.total_training_steps, args.lr)

        # Create our agent
        self.agent = PPOAgent(self.envs, self.actor, self.critic, self.memory)

    def rollout_phase(self) -> dict | None:
        """
        This function populates the memory with a new set of experiences, using `self.agent.play_step` to step through
        the environment. It also returns a dict of data which you can include in your progress bar postfix.
        """
        # raise NotImplementedError()
        for _ in range(self.args.num_steps_per_rollout):
            infos = self.agent.play_step()

    def learning_phase(self) -> None:
        """
        This function does the following:
            - Generates minibatches from memory
            - Calculates the objective function, and takes an optimization step based on it
            - Clips the gradients (see detail #11)
            - Steps the learning rate scheduler
        """
        # raise NotImplementedError()
        minibatches = self.agent.get_minibatches(self.args.gamma, self.args.gae_lambda)
        for minibatch in minibatches:
            obj_func = self.compute_ppo_objective(minibatch)
            obj_func.backward() # calculate gradients
            # grad clip
            nn.utils.clip_grad_norm_(list(self.actor.parameters()) + list(self.critic.parameters()), self.args.max_grad_norm)
            self.optimizer.step() # update weights
            self.optimizer.zero_grad()
        # step the scheduler
        self.scheduler.step()

    def compute_ppo_objective(self, minibatch: ReplayMinibatch) -> Float[Tensor, ""]:
        """
        Handles learning phase for a single minibatch. Returns objective function to be maximized.
        """
        logits = self.actor(minibatch.obs)
        dist = Categorical(logits=logits)
        # The critic outputs shape (minibatch_size, 1); `.squeeze()` drops the size-1 dimension to give a vector
        values = self.critic(minibatch.obs).squeeze()

        # Compute the three components of the objective
        surrogate_obj = calc_clipped_surrogate_objective(
            dist, minibatch.actions, minibatch.advantages, minibatch.logprobs, clip_coef=self.args.clip_coef
        )
        # `minibatch.returns` holds the critic's targets, i.e. `advantages + values` from the rollout
        vf_loss = calc_value_function_loss(values, minibatch.returns, self.args.vf_coef)
        ent_bonus = calc_entropy_bonus(dist, self.args.ent_coef)

        return surrogate_obj - vf_loss + ent_bonus

    def train(self) -> None:
        if self.args.use_wandb:
            wandb.init(
                project=self.args.wandb_project_name,
                entity=self.args.wandb_entity,
                name=self.run_name,
                monitor_gym=self.args.video_log_freq is not None,
            )
            wandb.watch([self.actor, self.critic], log="all", log_freq=50)

        pbar = tqdm(range(self.args.total_phases))
        last_logged_time = time.time()  # so we don't update the progress bar too much

        for phase in pbar:
            data = self.rollout_phase()
            if data is not None and time.time() - last_logged_time > 0.5:
                last_logged_time = time.time()
                pbar.set_postfix(phase=phase, **data)

            self.learning_phase()

        self.envs.close()
        if self.args.use_wandb:
            wandb.finish() 

<details>
<summary>Solution (simpler, no logging)</summary>

```python
def rollout_phase(self) -> dict | None:
    for step in range(self.args.num_steps_per_rollout):
        infos = self.agent.play_step()

def learning_phase(self) -> None:
    minibatches = self.agent.get_minibatches(self.args.gamma, self.args.gae_lambda)
    for minibatch in minibatches:
        objective_fn = self.compute_ppo_objective(minibatch)
        objective_fn.backward()
        nn.utils.clip_grad_norm_(
            list(self.actor.parameters()) + list(self.critic.parameters()), self.args.max_grad_norm
        )
        self.optimizer.step()
        self.optimizer.zero_grad()
    self.scheduler.step()

def compute_ppo_objective(self, minibatch: ReplayMinibatch) -> Float[Tensor, ""]:
    logits = self.actor(minibatch.obs)
    dist = Categorical(logits=logits)
    values = self.critic(minibatch.obs).squeeze()

    clipped_surrogate_objective = calc_clipped_surrogate_objective(
        dist, minibatch.actions, minibatch.advantages, minibatch.logprobs, self.args.clip_coef
    )
    value_loss = calc_value_function_loss(values, minibatch.returns, self.args.vf_coef)
    entropy_bonus = calc_entropy_bonus(dist, self.args.ent_coef)

    total_objective_function = clipped_surrogate_objective - value_loss + entropy_bonus
    return total_objective_function
```

</details>

<details>
<summary>Solution (full, with logging)</summary>

```python
def rollout_phase(self) -> dict | None:
    data = None
    t0 = time.time()

    for step in range(self.args.num_steps_per_rollout):
        # Play a step, returning the infos dict (containing information for each environment)
        infos = self.agent.play_step()

        # Get data from environments, and log it if some environment did actually terminate
        new_data = get_episode_data_from_infos(infos)
        if new_data is not None:
            data = new_data
            if self.args.use_wandb:
                wandb.log(new_data, step=self.agent.step)

    if self.args.use_wandb:
        wandb.log(
            {"SPS": (self.args.num_steps_per_rollout * self.num_envs) / (time.time() - t0)}, step=self.agent.step
        )

    return data

def learning_phase(self) -> None:
    minibatches = self.agent.get_minibatches(self.args.gamma, self.args.gae_lambda)
    for minibatch in minibatches:
        objective_fn = self.compute_ppo_objective(minibatch)
        objective_fn.backward()
        nn.utils.clip_grad_norm_(
            list(self.actor.parameters()) + list(self.critic.parameters()), self.args.max_grad_norm
        )
        self.optimizer.step()
        self.optimizer.zero_grad()
    self.scheduler.step()

def compute_ppo_objective(self, minibatch: ReplayMinibatch) -> Float[Tensor, ""]:
    logits = self.actor(minibatch.obs)
    dist = Categorical(logits=logits)
    values = self.critic(minibatch.obs).squeeze()

    clipped_surrogate_objective = calc_clipped_surrogate_objective(
        dist, minibatch.actions, minibatch.advantages, minibatch.logprobs, self.args.clip_coef
    )
    value_loss = calc_value_function_loss(values, minibatch.returns, self.args.vf_coef)
    entropy_bonus = calc_entropy_bonus(dist, self.args.ent_coef)

    total_objective_function = clipped_surrogate_objective - value_loss + entropy_bonus

    with t.inference_mode():
        newlogprob = dist.log_prob(minibatch.actions)
        logratio = newlogprob - minibatch.logprobs
        ratio = logratio.exp()
        approx_kl = (ratio - 1 - logratio).mean().item()
        clipfracs = [((ratio - 1.0).abs() > self.args.clip_coef).float().mean().item()]
    if self.args.use_wandb:
        wandb.log(
            dict(
                total_steps=self.agent.step,
                values=values.mean().item(),
                lr=self.scheduler.optimizer.param_groups[0]["lr"],
                value_loss=value_loss.item(),
                clipped_surrogate_objective=clipped_surrogate_objective.item(),
                entropy=entropy_bonus.item(),
                approx_kl=approx_kl,
                clipfrac=np.mean(clipfracs),
            ),
            step=self.agent.step,
        )

    return total_objective_function
```

</details>
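
The two diagnostics computed in the logging code above, `approx_kl` and `clipfrac`, are easy to sanity-check by hand. The sketch below reimplements the same formulas on plain Python lists. The function name `kl_and_clipfrac` and the example probabilities are assumptions made for illustration.

```python
import math


def kl_and_clipfrac(
    new_logprobs: list[float], old_logprobs: list[float], clip_coef: float
) -> tuple[float, float]:
    """Return (approx_kl, clipfrac) using the same formulas as the logging code."""
    logratios = [n - o for n, o in zip(new_logprobs, old_logprobs)]
    ratios = [math.exp(lg) for lg in logratios]
    approx_kl = sum(r - 1 - lg for r, lg in zip(ratios, logratios)) / len(ratios)
    clipfrac = sum(abs(r - 1.0) > clip_coef for r in ratios) / len(ratios)
    return approx_kl, clipfrac


old = [math.log(0.5)] * 4
# Unchanged policy: all ratios are 1, so both diagnostics are zero
print(kl_and_clipfrac(old, old, clip_coef=0.2))  # (0.0, 0.0)

# Policy moved on half the samples (probability 0.5 -> 0.8, so the ratio is about 1.6)
new = [math.log(0.8), math.log(0.8), math.log(0.5), math.log(0.5)]
approx_kl, clipfrac = kl_and_clipfrac(new, old, clip_coef=0.2)
print(clipfrac)  # 0.5
```

Each term `ratio - 1 - logratio` is non-negative, so `approx_kl` is zero only when the policy hasn't changed on the sampled actions. A `clipfrac` that stays near zero suggests the updates are too small for clipping to matter, while a very high value suggests the policy is changing a lot within a single learning phase.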

Here's some code to run your model on the probe environments (and assert that they're all working fine).

A brief recap of the probe environments, along with recommendations of where to go to debug if one of them fails (note that these won't be true 100% of the time, but should hopefully give you some useful direction):

* **Probe 1 tests basic learning ability**. If this fails, it means the agent has failed to learn to associate a constant observation with a constant reward. You should check your loss functions and optimizers in this case.
* **Probe 2 tests the agent's ability to differentiate between 2 different observations (and learn their respective values)**. If this fails, it means the agent has issues with handling multiple possible observations.
* **Probe 3 tests the agent's ability to handle time & reward delay**. If this fails, it means the agent has problems with multi-step scenarios of discounting future rewards. You should look at how your agent step function works.
* **Probe 4 tests the agent's ability to learn from actions leading to different rewards**. If this fails, it means the agent has failed to change its policy for different rewards, and you should look closer at how your agent is updating its policy based on the rewards it receives & the loss function.
* **Probe 5 tests the agent's ability to map observations to actions**. If this fails, you should look at the code which handles multiple timesteps, as well as the code that handles the agent's map from observations to actions.
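
For example, the expected critic values for Probe 3 in the test below are `[[args.gamma], [1.0]]`. The sketch below shows how such values arise from discounting, assuming (hypothetically) that an episode lasts two steps and that the only reward, +1, arrives on the final step. The helper `discounted_returns` and the value `gamma=0.99` are illustrative assumptions.

```python
def discounted_returns(rewards: list[float], gamma: float) -> list[float]:
    """Return G_t = r_t + gamma * r_{t+1} + gamma^2 * r_{t+2} + ... for every timestep t."""
    returns = []
    running = 0.0
    # Work backwards so that each return reuses the one computed for the next timestep
    for reward in reversed(rewards):
        running = reward + gamma * running
        returns.append(running)
    return returns[::-1]


# Assumed Probe 3 structure: two steps, reward of +1 only on the final step
print(discounted_returns([0.0, 1.0], gamma=0.99))  # [0.99, 1.0]
```

If your critic instead predicts a value close to 1.0 for both observations, that points to discounting not being applied between the two steps.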

In [90]:
def test_probe(probe_idx: int):
    """
    Tests a probe environment by training a network on it & verifying that the value functions are
    in the expected range.
    """
    # Train our network
    args = PPOArgs(
        env_id=f"Probe{probe_idx}-v0",
        wandb_project_name=f"test-probe-{probe_idx}",
        total_timesteps=[7500, 7500, 12500, 20000, 20000][probe_idx - 1],
        lr=0.001,
        video_log_freq=None,
        use_wandb=False,
    )
    trainer = PPOTrainer(args)
    trainer.train()
    agent = trainer.agent

    # Get the correct set of observations, and corresponding values we expect
    obs_for_probes = [[[0.0]], [[-1.0], [+1.0]], [[0.0], [1.0]], [[0.0]], [[0.0], [1.0]]]
    expected_value_for_probes = [[[1.0]], [[-1.0], [+1.0]], [[args.gamma], [1.0]], [[1.0]], [[1.0], [1.0]]]
    expected_probs_for_probes = [None, None, None, [[0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]]]
    tolerances = [1e-3, 1e-3, 1e-3, 2e-3, 2e-3]
    obs = t.tensor(obs_for_probes[probe_idx - 1]).to(device)

    # Calculate the actual value & probs, and verify them
    with t.inference_mode():
        value = agent.critic(obs)
        probs = agent.actor(obs).softmax(-1)
    expected_value = t.tensor(expected_value_for_probes[probe_idx - 1]).to(device)
    t.testing.assert_close(value, expected_value, atol=tolerances[probe_idx - 1], rtol=0)
    expected_probs = expected_probs_for_probes[probe_idx - 1]
    if expected_probs is not None:
        t.testing.assert_close(probs, t.tensor(expected_probs).to(device), atol=tolerances[probe_idx - 1], rtol=0)
    print("Probe tests passed!\n")


for probe_idx in range(1, 6):
    test_probe(probe_idx)

  7%|▋         | 1/14 [00:00<00:02,  4.52it/s]

torch.Size([])
torch.Size([128])
tensor(0.2500, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2419, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2337, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2253, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2166, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2076, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1984, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1889, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1792, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 14%|█▍        | 2/14 [00:00<00:02,  4.56it/s]

torch.Size([])
torch.Size([128])
tensor(0.0956, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0850, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0746, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0645, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0546, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0452, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0363, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0281, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0206, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 21%|██▏       | 3/14 [00:00<00:02,  4.58it/s]


 29%|██▊       | 4/14 [00:00<00:02,  4.56it/s]


 36%|███▌      | 5/14 [00:01<00:01,  4.57it/s]


 43%|████▎     | 6/14 [00:01<00:01,  4.58it/s]


 50%|█████     | 7/14 [00:01<00:01,  4.59it/s]


 57%|█████▋    | 8/14 [00:01<00:01,  4.59it/s]


 64%|██████▍   | 9/14 [00:01<00:01,  4.57it/s]


 71%|███████▏  | 10/14 [00:02<00:00,  4.56it/s]


 79%|███████▊  | 11/14 [00:02<00:00,  4.58it/s]


 86%|████████▌ | 12/14 [00:02<00:00,  4.58it/s]


 93%|█████████▎| 13/14 [00:02<00:00,  4.58it/s]


100%|██████████| 14/14 [00:03<00:00,  4.57it/s]



  7%|▋         | 1/14 [00:00<00:02,  4.55it/s]


 14%|█▍        | 2/14 [00:00<00:02,  4.54it/s]


 21%|██▏       | 3/14 [00:00<00:02,  4.53it/s]


 29%|██▊       | 4/14 [00:00<00:02,  4.54it/s]


 36%|███▌      | 5/14 [00:01<00:01,  4.53it/s]


 43%|████▎     | 6/14 [00:01<00:01,  4.51it/s]


 50%|█████     | 7/14 [00:01<00:01,  4.51it/s]


 57%|█████▋    | 8/14 [00:01<00:01,  4.52it/s]


 64%|██████▍   | 9/14 [00:01<00:01,  4.51it/s]


 71%|███████▏  | 10/14 [00:02<00:00,  4.53it/s]


 79%|███████▊  | 11/14 [00:02<00:00,  4.53it/s]


 86%|████████▌ | 12/14 [00:02<00:00,  4.54it/s]


 93%|█████████▎| 13/14 [00:02<00:00,  4.54it/s]


100%|██████████| 14/14 [00:03<00:00,  4.53it/s]



  4%|▍         | 1/24 [00:00<00:04,  4.78it/s]


  8%|▊         | 2/24 [00:00<00:04,  4.80it/s]


 12%|█▎        | 3/24 [00:00<00:04,  4.81it/s]


 17%|█▋        | 4/24 [00:00<00:04,  4.81it/s]


 21%|██        | 5/24 [00:01<00:03,  4.81it/s]


 25%|██▌       | 6/24 [00:01<00:03,  4.81it/s]


 29%|██▉       | 7/24 [00:01<00:03,  4.79it/s]

torch.Size([])
torch.Size([128])
tensor(1.0879e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.6801e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4536e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2923e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6346e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6582e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7178e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6864e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 33%|███▎      | 8/24 [00:01<00:03,  4.81it/s]

torch.Size([])
torch.Size([128])
tensor(5.0770e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.4301e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.9823e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.1881e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2193e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3645e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.1507e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.0934e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 38%|███▊      | 9/24 [00:01<00:03,  4.83it/s]

torch.Size([])
torch.Size([128])
tensor(4.4948e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.2686e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.3748e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.6301e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.5599e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.2489e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.4129e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.3445e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 42%|████▏     | 10/24 [00:02<00:02,  4.85it/s]

torch.Size([])
torch.Size([128])
tensor(1.0489e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.6580e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.3660e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.2607e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.5883e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3124e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.9899e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.5147e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 46%|████▌     | 11/24 [00:02<00:02,  4.84it/s]

torch.Size([])
torch.Size([128])
tensor(5.0959e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.5084e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.3701e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.2272e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.5500e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.3507e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.8213e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.4153e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 50%|█████     | 12/24 [00:02<00:02,  4.85it/s]

torch.Size([])
torch.Size([128])
tensor(4.6869e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.9726e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.9046e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8422e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0075e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.1273e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.8558e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1591e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 54%|█████▍    | 13/24 [00:02<00:02,  4.86it/s]

torch.Size([])
torch.Size([128])
tensor(1.7742e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.2555e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.3191e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.9203e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.3334e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.4843e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.5620e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.5332e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 58%|█████▊    | 14/24 [00:02<00:02,  4.88it/s]

torch.Size([])
torch.Size([128])
tensor(1.6714e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9834e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1231e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2433e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1634e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9784e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4997e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1728e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 62%|██████▎   | 15/24 [00:03<00:01,  4.88it/s]

torch.Size([])
torch.Size([128])
tensor(1.1996e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2692e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.5104e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1337e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0729e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.7424e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.7177e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.2934e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 67%|██████▋   | 16/24 [00:03<00:01,  4.88it/s]

torch.Size([])
torch.Size([128])
tensor(7.5222e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.5307e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.9593e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.3210e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.7080e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.0224e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.6784e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.9844e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 71%|███████   | 17/24 [00:03<00:01,  4.89it/s]

torch.Size([])
torch.Size([128])
tensor(8.4674e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0469e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.5567e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.6872e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.3483e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.9163e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.3837e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0136e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 75%|███████▌  | 18/24 [00:03<00:01,  4.89it/s]

torch.Size([])
torch.Size([128])
tensor(7.5244e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.6793e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.2121e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.1818e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.1015e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.5272e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.5087e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.6301e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 79%|███████▉  | 19/24 [00:03<00:01,  4.89it/s]

torch.Size([])
torch.Size([128])
tensor(7.8123e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.2632e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.9732e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.9269e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.5014e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.9790e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.8817e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.0757e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 83%|████████▎ | 20/24 [00:04<00:00,  4.89it/s]

torch.Size([])
torch.Size([128])
tensor(8.1310e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.8840e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.6962e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.1860e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.3150e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.9499e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.1074e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.0308e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 88%|████████▊ | 21/24 [00:04<00:00,  4.88it/s]

torch.Size([])
torch.Size([128])
tensor(6.9801e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.2197e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.5729e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.2734e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.2610e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.3910e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.7982e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.6738e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 92%|█████████▏| 22/24 [00:04<00:00,  4.88it/s]

torch.Size([])
torch.Size([128])
tensor(7.1301e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.2419e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.8258e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.5476e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.6459e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.0163e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.5434e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.0803e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 96%|█████████▌| 23/24 [00:04<00:00,  4.89it/s]

torch.Size([])
torch.Size([128])
tensor(5.8473e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.8038e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.2795e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.5133e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.0743e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.9205e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.1702e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.6470e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

100%|██████████| 24/24 [00:04<00:00,  4.86it/s]


torch.Size([])
torch.Size([128])
tensor(6.6432e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.8096e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.9468e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.5301e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.7112e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.6166e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.8262e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.5033e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

  3%|▎         | 1/39 [00:00<00:08,  4.53it/s]

  5%|▌         | 2/39 [00:00<00:08,  4.56it/s]
  8%|▊         | 3/39 [00:00<00:07,  4.56it/s]
 10%|█         | 4/39 [00:00<00:07,  4.59it/s]
 13%|█▎        | 5/39 [00:01<00:07,  4.61it/s]
 15%|█▌        | 6/39 [00:01<00:07,  4.63it/s]
 18%|█▊        | 7/39 [00:01<00:06,  4.62it/s]
 21%|██        | 8/39 [00:01<00:06,  4.60it/s]
 23%|██▎       | 9/39 [00:01<00:06,  4.60it/s]
 26%|██▌       | 10/39 [00:02<00:06,  4.61it/s]
 28%|██▊       | 11/39 [00:02<00:06,  4.61it/s]
 31%|███       | 12/39 [00:02<00:05,  4.61it/s]
 33%|███▎      | 13/39 [00:02<00:05,  4.61it/s]
 36%|███▌      | 14/39 [00:03<00:05,  4.60it/s]
torch.Size([])
torch.Size([128])
tensor(1.5063e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.6278e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.0614e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 38%|███▊      | 15/39 [00:03<00:05,  4.60it/s]

torch.Size([])
torch.Size([128])
tensor(6.9633e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.8818e-12, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.0649e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7078e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.8756e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.4329e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.1645e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2687e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 41%|████      | 16/39 [00:03<00:05,  4.58it/s]

torch.Size([])
torch.Size([128])
tensor(5.2446e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8932e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1768e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.5279e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3859e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.7230e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.2749e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.3697e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 44%|████▎     | 17/39 [00:03<00:04,  4.56it/s]

torch.Size([])
torch.Size([128])
tensor(1.1181e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.7226e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1151e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1511e-12, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5297e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.8562e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0825e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4168e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 46%|████▌     | 18/39 [00:03<00:04,  4.57it/s]

torch.Size([])
torch.Size([128])
tensor(8.7571e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.8642e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3888e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.8818e-16, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.2742e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2893e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.6091e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.7873e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 49%|████▊     | 19/39 [00:04<00:04,  4.57it/s]

torch.Size([])
torch.Size([128])
tensor(2.3102e-12, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5667e-12, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.2286e-12, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4780e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.7326e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.3669e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1768e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5967e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 51%|█████▏    | 20/39 [00:04<00:04,  4.58it/s]

torch.Size([])
torch.Size([128])
tensor(3.3427e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.9744e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.2063e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.9918e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.2879e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.8149e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.6459e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0147e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 54%|█████▍    | 21/39 [00:04<00:03,  4.59it/s]

torch.Size([])
torch.Size([128])
tensor(5.2879e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.9468e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.9888e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.3313e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.8663e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.6195e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.3669e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.1005e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 56%|█████▋    | 22/39 [00:04<00:03,  4.60it/s]

torch.Size([])
torch.Size([128])
tensor(6.0504e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.9580e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0078, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0078, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.8653e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.9584e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0078, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0078, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.1751e-05, device='cuda:0', dtype=t

 59%|█████▉    | 23/39 [00:05<00:03,  4.60it/s]

torch.Size([])
torch.Size([128])
tensor(3.5053e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6995e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.5116e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9221e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.0179e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5116e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0939e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0094e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 62%|██████▏   | 24/39 [00:05<00:03,  4.61it/s]

torch.Size([])
torch.Size([128])
tensor(5.4257e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0027e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3907e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.8642e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0064e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.6320e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.1677e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.3718e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 64%|██████▍   | 25/39 [00:05<00:03,  4.59it/s]

torch.Size([])
torch.Size([128])
tensor(6.6471e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2731e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2117e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.5918e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.1305e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.0970e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.3163e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.5330e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 67%|██████▋   | 26/39 [00:05<00:02,  4.59it/s]

torch.Size([])
torch.Size([128])
tensor(4.7970e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7195e-12, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.4680e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0462e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5194e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4568e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.4392e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.4399e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 69%|██████▉   | 27/39 [00:05<00:02,  4.60it/s]

torch.Size([])
torch.Size([128])
tensor(0.0078, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.9753e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4851e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1859e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1507e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0078, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.3671e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.4098e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.2836e-06, de

 72%|███████▏  | 28/39 [00:06<00:02,  4.60it/s]

torch.Size([])
torch.Size([128])
tensor(1.4051e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2936e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.1301e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.8806e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3718e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.8446e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.4673e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.8453e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 74%|███████▍  | 29/39 [00:06<00:02,  4.61it/s]

torch.Size([])
torch.Size([128])
tensor(2.5889e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4454e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.9728e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.5175e-12, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.0126e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.0603e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3099e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2392e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 77%|███████▋  | 30/39 [00:06<00:01,  4.60it/s]

torch.Size([])
torch.Size([128])
tensor(2.2098e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.4434e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.9796e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1269e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3465e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.7913e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2063e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0831e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 79%|███████▉  | 31/39 [00:06<00:01,  4.61it/s]

torch.Size([])
torch.Size([128])
tensor(5.6403e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0169e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.6529e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.4627e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.1159e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.0398e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.4606e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.4696e-13, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 82%|████████▏ | 32/39 [00:06<00:01,  4.61it/s]

torch.Size([])
torch.Size([128])
tensor(4.5788e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.2044e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.9796e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.5690e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0464e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.8818e-14, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8631e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.2074e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 85%|████████▍ | 33/39 [00:07<00:01,  4.61it/s]

torch.Size([])
torch.Size([128])
tensor(1.3928e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.2090e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9984e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.4773e-12, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.9936e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1412e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.3158e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.7180e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 87%|████████▋ | 34/39 [00:07<00:01,  4.61it/s]

torch.Size([])
torch.Size([128])
tensor(9.6723e-13, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7160e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.5956e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0510e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4640e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5817e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4140e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0449e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 90%|████████▉ | 35/39 [00:07<00:00,  4.60it/s]

torch.Size([])
torch.Size([128])
tensor(4.8218e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.8149e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.3165e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.7571e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.9936e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.5229e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.8633e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.4469e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 92%|█████████▏| 36/39 [00:07<00:00,  4.61it/s]

torch.Size([])
torch.Size([128])
tensor(5.8208e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.8663e-11, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0078, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.1622e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0632e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0078, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.4502e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.8209e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.3183e-06, de

 95%|█████████▍| 37/39 [00:08<00:00,  4.59it/s]

torch.Size([])
torch.Size([128])
tensor(1.5075e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1469e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.1565e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.6287e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.3797e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.4597e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.5809e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.1952e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 97%|█████████▋| 38/39 [00:08<00:00,  4.59it/s]

torch.Size([])
torch.Size([128])
tensor(7.2685e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2583e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.3276e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1662e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7226e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5715e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.9083e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2217e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

100%|██████████| 39/39 [00:08<00:00,  4.60it/s]


torch.Size([])
torch.Size([128])
tensor(1.7329e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.6385e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.1769e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.0345e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2267e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.2204e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3685e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2142e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

  3%|▎         | 1/39 [00:00<00:08,  4.52it/s]

torch.Size([])
torch.Size([128])
tensor(0.2497, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2528, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2562, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2495, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2496, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2479, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2519, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2500, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2512, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

  5%|▌         | 2/39 [00:00<00:08,  4.53it/s]

torch.Size([])
torch.Size([128])
tensor(0.2549, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2499, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2490, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2470, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2504, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2408, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2387, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2357, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2482, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

  8%|▊         | 3/39 [00:00<00:07,  4.53it/s]

torch.Size([])
torch.Size([128])
tensor(0.2092, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2126, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1870, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2067, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2072, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1933, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1888, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2028, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1957, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 10%|█         | 4/39 [00:00<00:07,  4.53it/s]

torch.Size([])
torch.Size([128])
tensor(0.1256, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1387, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1356, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1171, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1267, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1281, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1283, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1152, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0960, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 13%|█▎        | 5/39 [00:01<00:07,  4.53it/s]

torch.Size([])
torch.Size([128])
tensor(0.0901, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0672, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0728, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0671, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0740, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0578, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0644, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0582, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0703, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 15%|█▌        | 6/39 [00:01<00:07,  4.52it/s]

torch.Size([])
torch.Size([128])
tensor(0.0374, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0618, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0402, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0161, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0386, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0445, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0232, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0450, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0305, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 18%|█▊        | 7/39 [00:01<00:07,  4.53it/s]

torch.Size([])
torch.Size([128])
tensor(0.0007, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0005, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0001, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2644e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.7543e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0002, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0003, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0003, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0003, device='cuda:0', dtype=torch.float64, grad_fn=<Mul

 21%|██        | 8/39 [00:01<00:06,  4.53it/s]

torch.Size([])
torch.Size([128])
tensor(0.0002, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0078, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.1047e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0077, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5948e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0078, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0078, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.1426e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.8970e-05, device='cuda:0', dtype=torch.float6

 23%|██▎       | 9/39 [00:01<00:06,  4.53it/s]

torch.Size([])
torch.Size([128])
tensor(5.6526e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.4691e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0077, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.6083e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.7774e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9780e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0078, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2021e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0078, device

 26%|██▌       | 10/39 [00:02<00:06,  4.54it/s]

torch.Size([])
torch.Size([128])
tensor(1.2170e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0078, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1484e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0078, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0078, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2547e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0078, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6200e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7540e-05, device='cuda:0', dtype=t

 28%|██▊       | 11/39 [00:02<00:06,  4.54it/s]

torch.Size([])
torch.Size([128])
tensor(1.0815e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2356e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0809e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.8859e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.7632e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2244e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8764e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.7975e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 31%|███       | 12/39 [00:02<00:05,  4.54it/s]

torch.Size([])
torch.Size([128])
tensor(2.0064e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8705e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.7282e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.1993e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4629e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.7094e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.5137e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2121e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 33%|███▎      | 13/39 [00:02<00:05,  4.53it/s]

torch.Size([])
torch.Size([128])
tensor(3.6561e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0078, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.0746e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.9891e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.4977e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.3778e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.1392e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0078, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0078, device

 36%|███▌      | 14/39 [00:03<00:05,  4.53it/s]

torch.Size([])
torch.Size([128])
tensor(8.3826e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.9323e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.5770e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.9210e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.0958e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.1609e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.9715e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.6422e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 38%|███▊      | 15/39 [00:03<00:05,  4.54it/s]

torch.Size([])
torch.Size([128])
tensor(5.3550e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.7839e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0078, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8260e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.9107e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.6299e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.1423e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0078, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0675e-05, de

 41%|████      | 16/39 [00:03<00:05,  4.54it/s]

torch.Size([])
torch.Size([128])
tensor(0.0078, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1368e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.5003e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.9814e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.9133e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.0223e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0078, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.9175e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0078, device

 44%|████▎     | 17/39 [00:03<00:04,  4.53it/s]

torch.Size([])
torch.Size([128])
tensor(1.0272e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0844e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0001e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.9149e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.4360e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.9635e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.6222e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.6700e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 46%|████▌     | 18/39 [00:03<00:04,  4.53it/s]

torch.Size([])
torch.Size([128])
tensor(6.9912e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0134e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0078, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9284e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.5177e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0078, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2383e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2810e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0078, device

 49%|████▊     | 19/39 [00:04<00:04,  4.53it/s]

torch.Size([])
torch.Size([128])
tensor(6.7159e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.0369e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.9088e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.6290e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.2170e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.0657e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.9237e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.5164e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 51%|█████▏    | 20/39 [00:04<00:04,  4.53it/s]

torch.Size([])
torch.Size([128])
tensor(1.1005e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0993e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.7182e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1180e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1573e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4124e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3087e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.2475e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 54%|█████▍    | 21/39 [00:04<00:03,  4.52it/s]

torch.Size([])
torch.Size([128])
tensor(6.8842e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5967e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1815e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1524e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0068e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9312e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3034e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1966e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 56%|█████▋    | 22/39 [00:04<00:03,  4.53it/s]

torch.Size([])
torch.Size([128])
tensor(1.0096e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.0394e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4053e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.7485e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.7244e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.5151e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.9426e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.9140e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 59%|█████▉    | 23/39 [00:05<00:03,  4.53it/s]

torch.Size([])
torch.Size([128])
tensor(6.2373e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.3828e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0078, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5996e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.0970e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0078, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7202e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.6866e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4511e-05, de

 62%|██████▏   | 24/39 [00:05<00:03,  4.53it/s]

torch.Size([])
torch.Size([128])
tensor(3.0967e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0078, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3545e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7682e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0436e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0078, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.6852e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5662e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.9421e-06, de

 64%|██████▍   | 25/39 [00:05<00:03,  4.54it/s]

torch.Size([])
torch.Size([128])
tensor(1.2060e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.2043e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0304e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.5208e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.4333e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.9886e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.7621e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1286e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 67%|██████▋   | 26/39 [00:05<00:02,  4.52it/s]

torch.Size([])
torch.Size([128])
tensor(2.0309e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1113e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9327e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6206e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3162e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.6600e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.1677e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.4898e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 69%|██████▉   | 27/39 [00:05<00:02,  4.52it/s]

torch.Size([])
torch.Size([128])
tensor(0.0078, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0732e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.9448e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.3083e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.0387e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0078, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.5667e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1869e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1534e-05, de

 72%|███████▏  | 28/39 [00:06<00:02,  4.53it/s]

torch.Size([])
torch.Size([128])
tensor(4.4827e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.0630e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.2921e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.6406e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6693e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.6339e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2369e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2478e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 74%|███████▍  | 29/39 [00:06<00:02,  4.53it/s]

torch.Size([])
torch.Size([128])
tensor(8.6684e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.2680e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.3761e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.5803e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.3575e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.4914e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4224e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2577e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 77%|███████▋  | 30/39 [00:06<00:01,  4.53it/s]

torch.Size([])
torch.Size([128])
tensor(2.8765e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7882e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.9093e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4109e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.1717e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2694e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1861e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.8859e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 79%|███████▉  | 31/39 [00:06<00:01,  4.51it/s]

torch.Size([])
torch.Size([128])
tensor(4.6907e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.9988e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.1618e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.5020e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.6091e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.4125e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7157e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2512e-08, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 82%|████████▏ | 32/39 [00:07<00:01,  4.51it/s]

torch.Size([])
torch.Size([128])
tensor(1.3585e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.2331e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.8083e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.3294e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7405e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.2148e-10, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)

 85%|████████▍ | 33/39 [00:07<00:01,  4.51it/s]
 87%|████████▋ | 34/39 [00:07<00:01,  4.46it/s]
 90%|████████▉ | 35/39 [00:07<00:00,  4.36it/s]
 92%|█████████▏| 36/39 [00:07<00:00,  4.41it/s]
 95%|█████████▍| 37/39 [00:08<00:00,  4.45it/s]
 97%|█████████▋| 38/39 [00:08<00:00,  4.47it/s]
100%|██████████| 39/39 [00:08<00:00,  4.51it/s]

Once you've passed the tests for all 5 probe environments, test your model on CartPole.

[Here](https://api.wandb.ai/links/callum-mcdougall/fdmhh8gq) is an example of the wandb run you should expect.
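
Before launching the full CartPole run, it can help to make the "passed the probe tests" condition explicit. The snippet below is a minimal sketch, not the course's own test code: `check_probe_values` is a hypothetical helper, and the numbers are made up for illustration. It assumes you have already pulled the critic's predictions out as plain Python floats (for example with `.item()`), and it simply reports which predictions fall outside an absolute tolerance of the expected value.

```python
import math


def check_probe_values(predicted, expected, tol=1e-2):
    """Return (index, predicted, expected) for every value outside `tol`.

    Hypothetical helper for illustration; an empty list means every
    prediction is within the absolute tolerance of its expected value.
    """
    if len(predicted) != len(expected):
        raise ValueError("predicted and expected must have the same length")
    return [
        (i, p, e)
        for i, (p, e) in enumerate(zip(predicted, expected))
        if not math.isclose(p, e, abs_tol=tol, rel_tol=0.0)
    ]


# Made-up numbers: one critic output that is close to its target of 1.0,
# and one that is clearly off.
good = check_probe_values([0.998], [1.0])
bad = check_probe_values([0.41], [1.0])
print(good)  # []
print(bad)   # [(0, 0.41, 1.0)]
```

An empty list for every probe is a reasonable signal that the basic value and policy updates work, which makes a failing CartPole run much easier to diagnose.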

In [91]:
args = PPOArgs(use_wandb=True, video_log_freq=50)
trainer = PPOTrainer(args)
trainer.train()

error: XDG_RUNTIME_DIR not set in the environment.
wandb: Currently logged in as: ucfntxi (gis) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


  0%|          | 1/976 [00:00<05:33,  2.93it/s]

torch.Size([])
torch.Size([128])
tensor(20.0935, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.4780, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.6759, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.7175, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.4872, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.0320, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.2054, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.8897, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.1032, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  0%|          | 2/976 [00:00<04:32,  3.58it/s]

torch.Size([])
torch.Size([128])
tensor(19.8593, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.1510, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.1330, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.3366, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.8543, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.1416, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.2963, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.8359, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.9244, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  0%|          | 3/976 [00:00<04:17,  3.77it/s]

torch.Size([])
torch.Size([128])
tensor(17.3996, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(17.4690, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.0738, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(15.9101, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(16.9330, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(17.6048, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(16.4126, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(17.6027, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(16.7730, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  0%|          | 4/976 [00:01<04:18,  3.76it/s]

torch.Size([])
torch.Size([128])
tensor(20.5983, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.5215, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.9767, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.9514, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.9893, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.1160, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.0227, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.5502, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.0651, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  1%|          | 5/976 [00:01<04:08,  3.91it/s]

torch.Size([])
torch.Size([128])
tensor(25.1154, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.6590, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.8263, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.6770, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.2524, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.6795, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.5289, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.2791, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.0353, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  1%|          | 6/976 [00:01<04:11,  3.86it/s]

torch.Size([])
torch.Size([128])
tensor(25.7441, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.5202, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.4559, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.4682, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.8487, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.1264, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.3515, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.2225, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.6850, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  1%|          | 7/976 [00:01<04:14,  3.81it/s]

torch.Size([])
torch.Size([128])
tensor(32.7712, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.0433, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.5698, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.8581, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.6438, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.1680, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.5958, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.8717, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.8823, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  1%|          | 8/976 [00:02<04:06,  3.92it/s]

torch.Size([])
torch.Size([128])
tensor(28.2519, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.0477, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.1355, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.5120, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.0854, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.4698, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.0732, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.1536, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.5153, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  1%|          | 9/976 [00:02<04:04,  3.96it/s]

torch.Size([])
torch.Size([128])
tensor(31.4082, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.2456, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.8850, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.6652, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.4671, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.2439, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.8793, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.2936, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.8980, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  1%|          | 10/976 [00:02<04:12,  3.83it/s]

torch.Size([])
torch.Size([128])
tensor(25.7327, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.9104, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.3528, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.1796, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.1709, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.6508, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.2611, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.6281, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.3496, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  1%|          | 11/976 [00:02<04:34,  3.52it/s]

torch.Size([])
torch.Size([128])
tensor(28.6682, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.4057, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.4476, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.8769, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.6822, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.2296, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.5823, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.3395, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.9828, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  1%|          | 12/976 [00:03<04:28,  3.59it/s]

torch.Size([])
torch.Size([128])
tensor(35.0536, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.9624, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.9408, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.9016, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.4474, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.9822, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.9940, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.3767, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.7924, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  1%|▏         | 13/976 [00:03<04:33,  3.52it/s]

torch.Size([])
torch.Size([128])
tensor(33.0605, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.8900, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.1163, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.6979, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.1460, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.3443, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(37.4696, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.3993, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.7989, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  1%|▏         | 14/976 [00:03<04:35,  3.49it/s]

torch.Size([])
torch.Size([128])
tensor(29.8226, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.4742, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.4046, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.6479, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.2397, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.8738, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.6188, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.3448, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.8374, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  2%|▏         | 15/976 [00:04<04:34,  3.51it/s]

torch.Size([])
torch.Size([128])
tensor(36.9120, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.6289, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(41.0757, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(39.7845, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(41.0037, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(38.5644, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.9100, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.0219, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(40.7779, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  2%|▏         | 16/976 [00:04<04:37,  3.46it/s]

torch.Size([])
torch.Size([128])
tensor(34.7720, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.7554, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.3501, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.1236, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.2722, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.5244, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.8101, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.5527, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.7018, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  2%|▏         | 17/976 [00:04<04:34,  3.49it/s]

torch.Size([])
torch.Size([128])
tensor(27.8404, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.3429, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.3662, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.3608, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.4209, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.8278, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.4506, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.5562, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.9595, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  2%|▏         | 18/976 [00:04<04:22,  3.65it/s]

torch.Size([])
torch.Size([128])
tensor(36.3274, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(37.6853, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.4734, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.7076, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.3122, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.8064, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(38.1617, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.4802, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.0347, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  2%|▏         | 19/976 [00:05<04:30,  3.54it/s]

torch.Size([])
torch.Size([128])
tensor(34.2324, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.1634, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.7001, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.3097, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.6451, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.1652, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.9619, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.4771, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.4951, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  2%|▏         | 20/976 [00:05<04:30,  3.53it/s]

torch.Size([])
torch.Size([128])
tensor(40.3763, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(38.6482, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(39.6460, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.2904, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(38.6230, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.3444, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(37.4572, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(39.8476, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(39.1225, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  2%|▏         | 21/976 [00:05<04:35,  3.47it/s]

torch.Size([])
torch.Size([128])
tensor(37.5795, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(39.2050, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(37.5513, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.7382, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(38.9852, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(37.6939, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.2393, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.8991, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.6982, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  2%|▏         | 22/976 [00:06<04:36,  3.44it/s]

torch.Size([])
torch.Size([128])
tensor(35.8204, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.9076, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.5125, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.0445, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.0694, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.7288, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.1568, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.5109, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.5188, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  2%|▏         | 23/976 [00:06<04:31,  3.50it/s]

torch.Size([])
torch.Size([128])
tensor(33.3354, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.7742, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.0570, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.5630, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.5279, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.9846, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.8400, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.0499, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.6930, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  2%|▏         | 24/976 [00:06<04:16,  3.72it/s]

torch.Size([])
torch.Size([128])
tensor(31.3193, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.5793, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.6504, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.8509, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.5440, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.0309, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.4679, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.7019, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.6504, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  3%|▎         | 25/976 [00:06<04:16,  3.71it/s]

torch.Size([])
torch.Size([128])
tensor(37.8921, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(39.6016, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(38.0833, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(40.0051, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.7275, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(37.7970, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(38.5171, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(38.5022, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(37.2644, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

  3%|▎         | 26/976 [00:07<04:09,  3.81it/s]

torch.Size([])
torch.Size([128])
tensor(41.5717, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(39.8887, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(37.8003, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(38.4986, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(38.9501, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(40.8837, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.8149, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(38.6307, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(37.4425, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

  3%|▎         | 27/976 [00:07<04:07,  3.83it/s]

torch.Size([])
torch.Size([128])
tensor(30.9248, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.4891, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.6921, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.9059, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.2231, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.2005, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.8728, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.1540, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.5434, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

  3%|▎         | 28/976 [00:07<03:59,  3.96it/s]

torch.Size([])
torch.Size([128])
tensor(28.5352, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.2185, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.5815, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.5363, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.1857, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.9498, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.8131, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.8048, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.2300, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

  3%|▎         | 29/976 [00:07<04:11,  3.76it/s]

torch.Size([])
torch.Size([128])
tensor(36.3348, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(38.5658, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(38.2236, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.0371, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(37.6985, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(37.8280, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.1875, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.6741, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.9139, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

  3%|▎         | 30/976 [00:08<04:06,  3.84it/s]

torch.Size([])
torch.Size([128])
tensor(32.6464, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.2191, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.7447, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(37.2764, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.5089, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.0787, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.6873, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.2596, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.9703, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

  3%|▎         | 31/976 [00:08<04:01,  3.91it/s]

torch.Size([])
torch.Size([128])
tensor(36.0779, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.6066, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.2708, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.9481, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.6267, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.2615, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.8598, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.7785, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.0750, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

  3%|▎         | 32/976 [00:08<03:57,  3.97it/s]

torch.Size([])
torch.Size([128])
tensor(30.7404, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.7298, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.0716, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.6385, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.2492, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.5624, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.1567, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.0730, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.4388, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

  3%|▎         | 33/976 [00:08<04:01,  3.91it/s]

torch.Size([])
torch.Size([128])
tensor(31.8797, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.6624, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.5630, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.3919, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.6977, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.0318, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.2918, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.2278, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.3141, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

  3%|▎         | 34/976 [00:09<04:01,  3.89it/s]

torch.Size([])
torch.Size([128])
tensor(36.6962, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.8937, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.6560, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.5166, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.1343, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.6998, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.8291, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.9050, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.0200, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

  4%|▎         | 35/976 [00:09<04:01,  3.90it/s]

torch.Size([])
torch.Size([128])
tensor(39.3322, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(41.7312, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(39.8041, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(41.2662, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(38.3095, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(40.7393, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(39.1315, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(40.2084, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(40.4317, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

  4%|▎         | 36/976 [00:09<04:04,  3.84it/s]

torch.Size([])
torch.Size([128])
tensor(35.0532, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.2980, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.5790, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.2618, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.4111, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.0384, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.7531, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.7475, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.6832, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

  4%|▍         | 37/976 [00:09<04:01,  3.89it/s]

torch.Size([])
torch.Size([128])
tensor(35.5952, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(37.5517, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.0870, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.8486, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.8220, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.5572, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.8403, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.9791, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.3662, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

  4%|▍         | 38/976 [00:10<04:04,  3.84it/s]

torch.Size([])
torch.Size([128])
tensor(38.5690, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.1710, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.2543, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.4139, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.5223, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.5978, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.3348, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.7376, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.6822, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

  4%|▍         | 39/976 [00:10<03:59,  3.91it/s]

torch.Size([])
torch.Size([128])
tensor(36.1811, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.2313, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.6488, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.7349, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.1378, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.2670, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.2801, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.1947, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.0533, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

  4%|▍         | 40/976 [00:10<03:55,  3.97it/s]

torch.Size([])
torch.Size([128])
tensor(25.9656, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.4136, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.7871, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.1114, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.6456, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.7223, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.1714, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.7913, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.2636, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

  4%|▍         | 41/976 [00:10<03:53,  4.01it/s]

torch.Size([])
torch.Size([128])
tensor(28.6591, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.0455, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.4175, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.8195, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.1733, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.6298, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.7558, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.2392, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.0334, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

  4%|▍         | 42/976 [00:11<04:00,  3.89it/s]

torch.Size([])
torch.Size([128])
tensor(36.4014, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.9845, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.5637, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.2128, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.0671, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.8347, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.5662, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.2423, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.1987, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

  4%|▍         | 43/976 [00:11<04:09,  3.75it/s]

torch.Size([])
torch.Size([128])
tensor(34.9913, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.6822, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.4473, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.9144, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.2897, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.2997, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.1888, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.0369, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.6076, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

  5%|▍         | 44/976 [00:11<04:16,  3.63it/s]

torch.Size([])
torch.Size([128])
tensor(26.9145, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.1242, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.3056, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.9529, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.0053, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.0365, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.3435, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.0781, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.0299, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

  5%|▍         | 45/976 [00:12<04:18,  3.60it/s]

torch.Size([])
torch.Size([128])
tensor(31.0115, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.7974, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.6341, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.4900, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.3412, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.6532, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.4693, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.2867, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.5025, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

  5%|▍         | 46/976 [00:12<04:02,  3.84it/s]

torch.Size([])
torch.Size([128])
tensor(33.7851, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.8792, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.1768, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.4391, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.7617, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.7393, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.2184, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.4716, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.9133, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

  5%|▍         | 47/976 [00:12<04:07,  3.75it/s]

torch.Size([])
torch.Size([128])
tensor(35.8375, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.9886, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.2253, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.9340, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.6279, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.2942, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.8601, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.0081, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.4864, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

  5%|▍         | 48/976 [00:12<04:10,  3.70it/s]

torch.Size([])
torch.Size([128])
tensor(29.0442, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.1862, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.8328, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.9439, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.1352, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.3701, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.8790, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.1148, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.6199, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

  5%|▌         | 49/976 [00:13<03:59,  3.88it/s]

torch.Size([])
torch.Size([128])
tensor(26.4160, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.1206, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.7178, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.1778, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.8713, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.4829, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.1373, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.3009, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.9013, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

  5%|▌         | 50/976 [00:13<03:54,  3.95it/s]

torch.Size([])
torch.Size([128])
tensor(35.2794, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.5661, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.7506, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.5050, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.5190, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.4090, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.8005, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.2794, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.5970, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

  5%|▌         | 51/976 [00:13<03:47,  4.06it/s]

torch.Size([])
torch.Size([128])
tensor(31.8813, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.8312, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.5811, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.5457, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.3349, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.9556, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.4258, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.2378, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.9801, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

  5%|▌         | 52/976 [00:13<03:44,  4.11it/s]

torch.Size([])
torch.Size([128])
tensor(31.4712, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.2407, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.7613, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.2901, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.4126, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.4657, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.8016, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.2454, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.2354, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

  5%|▌         | 53/976 [00:14<03:38,  4.23it/s]

torch.Size([])
torch.Size([128])
tensor(32.8500, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.7667, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.4770, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.5948, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.9575, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.4233, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.7602, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.3420, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.0248, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

  6%|▌         | 54/976 [00:14<03:40,  4.18it/s]

torch.Size([])
torch.Size([128])
tensor(33.6561, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.1909, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.9428, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.8269, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.0113, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.1762, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.7976, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.8107, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.0022, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

  6%|▌         | 55/976 [00:14<03:36,  4.25it/s]

torch.Size([])
torch.Size([128])
tensor(30.2960, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.1100, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.6754, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.3641, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.3184, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.4987, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.6555, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.2020, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.3931, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  6%|▌         | 56/976 [00:14<03:48,  4.03it/s]

torch.Size([])
torch.Size([128])
tensor(29.1787, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.7928, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.3811, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.5197, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.3497, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.1741, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.0159, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.6555, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.5420, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  6%|▌         | 57/976 [00:15<03:44,  4.10it/s]

torch.Size([])
torch.Size([128])
tensor(28.8842, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.7740, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.3204, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.6897, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.2393, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.2532, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.1894, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.4767, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.8343, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  6%|▌         | 58/976 [00:15<03:42,  4.13it/s]

torch.Size([])
torch.Size([128])
tensor(33.2807, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.5341, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.3381, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.3918, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.8601, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.2590, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.1004, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.3555, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.9962, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  6%|▌         | 59/976 [00:15<03:51,  3.97it/s]

torch.Size([])
torch.Size([128])
tensor(29.3885, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.5378, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.5534, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.4131, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.6211, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.6111, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.8291, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.0440, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.6557, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  6%|▌         | 60/976 [00:15<04:01,  3.80it/s]

torch.Size([])
torch.Size([128])
tensor(28.3055, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.0014, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.8288, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.2244, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.9095, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.3757, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.0663, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.6333, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.2500, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  6%|▋         | 61/976 [00:16<04:04,  3.75it/s]

torch.Size([])
torch.Size([128])
tensor(30.5503, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.5257, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.1190, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.5963, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.2358, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.5149, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.5936, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.6081, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.9080, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  6%|▋         | 62/976 [00:16<03:53,  3.91it/s]

torch.Size([])
torch.Size([128])
tensor(27.3654, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.7864, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.9228, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.0192, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.8509, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.7215, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.5765, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.4772, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.1758, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  6%|▋         | 63/976 [00:16<04:02,  3.76it/s]

torch.Size([])
torch.Size([128])
tensor(26.7922, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.0484, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.2144, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.7713, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.4561, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.4588, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.3961, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.0697, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.2647, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  7%|▋         | 64/976 [00:16<04:08,  3.67it/s]

torch.Size([])
torch.Size([128])
tensor(30.7789, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.3509, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.5587, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.0295, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.9765, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.1615, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.0215, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.8222, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.1909, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  7%|▋         | 65/976 [00:17<04:06,  3.70it/s]

torch.Size([])
torch.Size([128])
tensor(28.4886, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.3885, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.6041, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.8423, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.2065, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.7622, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.3475, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.3375, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.4802, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  7%|▋         | 66/976 [00:17<04:00,  3.79it/s]

torch.Size([])
torch.Size([128])
tensor(26.4993, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.5548, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.8298, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.9561, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.3038, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.9427, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.6883, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.5810, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.8769, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  7%|▋         | 67/976 [00:17<03:59,  3.80it/s]

torch.Size([])
torch.Size([128])
tensor(28.9649, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.3976, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.4875, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.3721, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.1793, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.1186, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.2907, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.3591, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.1120, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  7%|▋         | 68/976 [00:17<03:55,  3.85it/s]

torch.Size([])
torch.Size([128])
tensor(28.7089, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.3452, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.8642, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.1118, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.6341, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.0208, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.6095, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.9984, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.3978, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  7%|▋         | 69/976 [00:18<04:04,  3.72it/s]

torch.Size([])
torch.Size([128])
tensor(32.4307, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.6715, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.7590, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.3830, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.0024, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.6118, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.2220, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.3102, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.0051, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  7%|▋         | 70/976 [00:18<04:19,  3.49it/s]

torch.Size([])
torch.Size([128])
tensor(29.7010, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.9082, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.2234, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.2576, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.9035, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.3103, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.0098, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.4723, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.6347, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  7%|▋         | 71/976 [00:19<05:19,  2.83it/s]

torch.Size([])
torch.Size([128])
tensor(25.2735, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.5583, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.4505, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.7728, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.9572, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.9794, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.1692, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.7817, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.6355, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  7%|▋         | 72/976 [00:19<06:06,  2.46it/s]

torch.Size([])
torch.Size([128])
tensor(27.9964, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.5548, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.3620, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.5873, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.9605, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.2789, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.6379, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.0049, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.6095, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  7%|▋         | 73/976 [00:20<06:33,  2.29it/s]

torch.Size([])
torch.Size([128])
tensor(27.3224, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.9411, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.0935, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.6559, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.4347, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.6892, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.9308, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.6578, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.0303, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  8%|▊         | 74/976 [00:20<08:26,  1.78it/s]

torch.Size([])
torch.Size([128])
tensor(25.6063, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.1654, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.3277, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.3106, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.4014, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.3172, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.0922, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.2001, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.2969, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  8%|▊         | 75/976 [00:21<06:59,  2.15it/s]

torch.Size([])
torch.Size([128])
tensor(27.4794, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.8588, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.4007, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.6479, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.1414, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.0433, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.2437, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.5381, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.8553, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  8%|▊         | 76/976 [00:21<05:59,  2.51it/s]

torch.Size([])
torch.Size([128])
tensor(27.6108, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.8099, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.5294, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.1769, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.2285, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.4594, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.9758, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.2005, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.4461, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  8%|▊         | 77/976 [00:21<05:21,  2.80it/s]

torch.Size([])
torch.Size([128])
tensor(25.7430, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.1682, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.4038, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.4341, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.3819, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.7037, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.0095, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.4699, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.3933, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  8%|▊         | 78/976 [00:21<04:56,  3.02it/s]

torch.Size([])
torch.Size([128])
tensor(20.9187, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.7493, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.5970, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.7760, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.1442, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.7341, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.5124, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.7251, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.0920, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  8%|▊         | 79/976 [00:22<04:43,  3.16it/s]

torch.Size([])
torch.Size([128])
tensor(26.5349, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.3246, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.4814, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.0927, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.7139, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.7158, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.4749, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.3439, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.4665, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  8%|▊         | 80/976 [00:22<04:34,  3.26it/s]

torch.Size([])
torch.Size([128])
tensor(27.9854, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.0400, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.8105, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.6097, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.0863, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.6352, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.4418, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.2024, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.6512, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  8%|▊         | 81/976 [00:22<04:22,  3.41it/s]

torch.Size([])
torch.Size([128])
tensor(26.6677, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.9576, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.4986, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.5522, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.4486, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.2424, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.6582, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.4424, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.3942, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  8%|▊         | 82/976 [00:23<04:16,  3.48it/s]

torch.Size([])
torch.Size([128])
tensor(26.5737, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.4580, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.8986, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.5448, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.5365, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.0254, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.1676, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.7135, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.7094, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  9%|▊         | 83/976 [00:23<04:11,  3.56it/s]

torch.Size([])
torch.Size([128])
tensor(26.0082, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.0058, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.9550, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.6816, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.0098, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.8225, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.0714, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.5004, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.7028, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  9%|▊         | 84/976 [00:23<04:01,  3.69it/s]

torch.Size([])
torch.Size([128])
tensor(27.0104, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.8454, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.4806, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.2718, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.8967, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.0805, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.7382, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.1164, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.5101, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  9%|▊         | 85/976 [00:23<04:00,  3.70it/s]

torch.Size([])
torch.Size([128])
tensor(23.6589, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.5593, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.0466, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.2209, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.8617, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.7641, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.0723, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.7135, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.0758, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  9%|▉         | 86/976 [00:24<04:03,  3.66it/s]

torch.Size([])
torch.Size([128])
tensor(25.4453, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.2002, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.5216, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.8943, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.5353, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.2642, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.0695, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.1403, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.9804, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  9%|▉         | 87/976 [00:24<04:03,  3.66it/s]

torch.Size([])
torch.Size([128])
tensor(25.1734, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.9825, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.0515, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.2554, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.7441, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.6585, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.9974, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.6736, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.8786, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  9%|▉         | 88/976 [00:24<03:58,  3.72it/s]

torch.Size([])
torch.Size([128])
tensor(31.1529, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.4887, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.3713, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.9114, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.8297, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.7570, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.8722, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.5640, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.5012, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  9%|▉         | 89/976 [00:24<03:55,  3.77it/s]

torch.Size([])
torch.Size([128])
tensor(30.8011, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.7468, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.9254, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.4804, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.1276, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.1824, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.5041, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.4114, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.0418, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  9%|▉         | 90/976 [00:25<03:54,  3.78it/s]

torch.Size([])
torch.Size([128])
tensor(28.0052, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.9411, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.9010, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.0950, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.6565, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.2728, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.6324, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.0872, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.4424, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  9%|▉         | 91/976 [00:25<03:56,  3.74it/s]

torch.Size([])
torch.Size([128])
tensor(22.5277, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.0941, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.6670, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.1953, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.5088, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.8864, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.4008, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.2757, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.8323, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

  9%|▉         | 92/976 [00:25<03:57,  3.73it/s]

torch.Size([])
torch.Size([128])
tensor(24.9035, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.6733, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.0325, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.1595, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.7603, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.3930, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.6053, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.0463, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.2846, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 10%|▉         | 93/976 [00:25<03:56,  3.74it/s]

torch.Size([])
torch.Size([128])
tensor(22.6400, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.0335, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.7462, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.3233, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.2304, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.2999, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.5722, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.7497, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.8039, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 10%|▉         | 94/976 [00:26<03:52,  3.79it/s]

torch.Size([])
torch.Size([128])
tensor(23.4340, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.9592, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.0171, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.3531, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.4209, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.9451, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.7148, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.5484, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.4490, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 10%|▉         | 95/976 [00:26<03:56,  3.73it/s]

torch.Size([])
torch.Size([128])
tensor(21.6979, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.0729, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.1986, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.4468, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.0105, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.0599, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.1132, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.9162, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.3099, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 10%|▉         | 96/976 [00:26<04:01,  3.64it/s]

torch.Size([])
torch.Size([128])
tensor(21.9319, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.9617, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.3007, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.4265, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.6124, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.3893, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.7170, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.8168, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.0234, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 10%|▉         | 97/976 [00:27<04:05,  3.58it/s]

torch.Size([])
torch.Size([128])
tensor(21.7942, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.1916, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.4693, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.2212, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.6550, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.1743, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.0261, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.4246, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.7019, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 10%|█         | 98/976 [00:27<04:04,  3.60it/s]

torch.Size([])
torch.Size([128])
tensor(22.8497, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.2248, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.6328, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.0521, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.3668, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.4305, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.0013, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.8238, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.0806, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 10%|█         | 99/976 [00:27<03:51,  3.79it/s]

torch.Size([])
torch.Size([128])
tensor(19.4319, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.5145, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.7469, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.8687, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.3078, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.5264, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.5377, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.1592, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.8344, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 10%|█         | 100/976 [00:27<03:46,  3.87it/s]

torch.Size([])
torch.Size([128])
tensor(22.1351, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.8885, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.3958, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.3526, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.5290, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.5928, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.5750, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.7007, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.2248, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 10%|█         | 101/976 [00:28<03:41,  3.95it/s]

torch.Size([])
torch.Size([128])
tensor(24.9921, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.5154, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.3374, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.7566, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.9022, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.9508, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.5851, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.8410, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.2566, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 10%|█         | 102/976 [00:28<03:36,  4.03it/s]

torch.Size([])
torch.Size([128])
tensor(24.4041, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.1356, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.0067, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.3427, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.0944, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.0032, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.0725, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.6439, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.5321, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 11%|█         | 103/976 [00:28<03:38,  4.00it/s]

torch.Size([])
torch.Size([128])
tensor(20.5008, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.2608, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.1066, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.0383, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.5060, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.8750, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.8111, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.4565, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.5911, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 11%|█         | 104/976 [00:28<03:36,  4.03it/s]

torch.Size([])
torch.Size([128])
tensor(20.2068, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.0261, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.1076, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.8017, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.2006, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.2377, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.2727, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.1428, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.4858, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 11%|█         | 105/976 [00:29<03:42,  3.92it/s]

torch.Size([])
torch.Size([128])
tensor(25.9181, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.6384, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.8990, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.2768, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.8534, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.5179, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.6977, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.8157, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.2017, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 11%|█         | 106/976 [00:29<03:47,  3.83it/s]

torch.Size([])
torch.Size([128])
tensor(23.9805, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.2397, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.3587, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.0457, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.2543, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.3124, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.5943, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.4545, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.8492, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 11%|█         | 107/976 [00:29<03:40,  3.94it/s]

torch.Size([])
torch.Size([128])
tensor(35.7226, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.1328, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.0173, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.4242, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.1153, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.0420, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.8905, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.8366, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.8421, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 11%|█         | 108/976 [00:29<03:37,  3.98it/s]

torch.Size([])
torch.Size([128])
tensor(20.4987, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.0677, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.6101, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.4787, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.3129, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.1178, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.6972, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.2839, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.9664, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 11%|█         | 109/976 [00:30<03:41,  3.92it/s]

torch.Size([])
torch.Size([128])
tensor(28.0320, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.4047, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.3783, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.2369, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.3657, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.6507, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.4947, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.5475, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.7549, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 11%|█▏        | 110/976 [00:30<03:35,  4.02it/s]

torch.Size([])
torch.Size([128])
tensor(25.4853, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.4358, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.6694, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.9730, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.6058, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.8600, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.8588, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.9954, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.0811, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 11%|█▏        | 111/976 [00:30<03:42,  3.90it/s]

torch.Size([])
torch.Size([128])
tensor(23.9752, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.5076, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.5579, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.7812, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.5981, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.7110, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.2616, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.1583, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.9664, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 11%|█▏        | 112/976 [00:30<03:50,  3.75it/s]

torch.Size([])
torch.Size([128])
tensor(24.0735, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.6633, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.5585, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.0769, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.6559, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.4557, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.4893, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.6234, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.0155, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 12%|█▏        | 113/976 [00:31<03:55,  3.67it/s]

torch.Size([])
torch.Size([128])
tensor(19.4249, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.1947, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.7526, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.2017, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.3847, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.9977, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.9202, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.2957, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.5599, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 12%|█▏        | 114/976 [00:31<04:05,  3.51it/s]

torch.Size([])
torch.Size([128])
tensor(34.6710, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.6983, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(44.4501, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(37.4285, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(46.1739, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.9124, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.6673, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(42.2649, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.3545, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 12%|█▏        | 115/976 [00:31<04:07,  3.48it/s]

torch.Size([])
torch.Size([128])
tensor(33.0200, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.9120, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.5728, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.1579, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.0248, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.1132, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.5167, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.2704, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.7227, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 12%|█▏        | 116/976 [00:32<04:10,  3.43it/s]

torch.Size([])
torch.Size([128])
tensor(20.4735, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.2125, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.4486, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.9162, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.4680, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.9563, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.0437, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.8269, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.1930, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 12%|█▏        | 117/976 [00:32<04:07,  3.47it/s]

torch.Size([])
torch.Size([128])
tensor(24.4883, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.3948, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(47.0752, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.2264, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.5642, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(38.7990, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.5603, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(38.9551, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.9802, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 12%|█▏        | 118/976 [00:32<04:18,  3.32it/s]

torch.Size([])
torch.Size([128])
tensor(29.5905, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.2293, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.3505, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.4787, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.0309, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.5967, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.8791, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.5610, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.2582, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 12%|█▏        | 119/976 [00:33<04:19,  3.30it/s]

torch.Size([])
torch.Size([128])
tensor(28.3175, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(43.9103, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(38.2151, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.2256, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(38.2111, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.6688, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.3330, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(40.0590, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.5408, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 12%|█▏        | 120/976 [00:33<04:07,  3.45it/s]

torch.Size([])
torch.Size([128])
tensor(41.3183, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.6698, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.2608, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.4022, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(44.0185, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.6742, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.3246, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.1291, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.8151, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 12%|█▏        | 121/976 [00:33<03:55,  3.64it/s]

torch.Size([])
torch.Size([128])
tensor(22.2942, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.0749, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.7552, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.1874, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.6002, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.1629, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.8299, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.9813, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.7409, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 12%|█▎        | 122/976 [00:33<04:01,  3.54it/s]

torch.Size([])
torch.Size([128])
tensor(20.2520, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.8411, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.8529, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.7240, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.6969, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.9909, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.1978, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.9879, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.5914, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 13%|█▎        | 123/976 [00:34<03:54,  3.64it/s]

torch.Size([])
torch.Size([128])
tensor(25.5645, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.1331, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.0570, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(39.4747, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.3475, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.8958, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.5317, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.9200, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.9479, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 13%|█▎        | 124/976 [00:34<03:43,  3.82it/s]

torch.Size([])
torch.Size([128])
tensor(16.8905, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.7025, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.5894, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.5215, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.0783, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.0638, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.3249, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.6005, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.2768, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 13%|█▎        | 125/976 [00:34<03:51,  3.67it/s]

torch.Size([])
torch.Size([128])
tensor(34.3535, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.4021, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.0993, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.4546, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.5221, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.1015, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.0975, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.1725, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.4008, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 13%|█▎        | 126/976 [00:34<03:45,  3.76it/s]

torch.Size([])
torch.Size([128])
tensor(30.8256, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.2373, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.1891, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.7609, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.6242, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.5758, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.0337, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(17.6961, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.0904, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 13%|█▎        | 127/976 [00:35<03:44,  3.79it/s]

torch.Size([])
torch.Size([128])
tensor(35.2074, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.6322, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.1449, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.9225, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.1390, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.4503, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.8089, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.9620, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.1595, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 13%|█▎        | 128/976 [00:35<03:51,  3.67it/s]

torch.Size([])
torch.Size([128])
tensor(30.4308, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.2618, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.3790, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.7924, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.5845, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.0558, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.7496, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(16.8468, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.1332, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 13%|█▎        | 129/976 [00:35<03:57,  3.57it/s]

torch.Size([])
torch.Size([128])
tensor(26.4117, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.9028, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.7512, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(37.8152, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.9417, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.3667, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(38.7672, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.2159, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.0240, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 13%|█▎        | 130/976 [00:36<03:59,  3.54it/s]

torch.Size([])
torch.Size([128])
tensor(24.5733, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(17.8290, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.4937, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.1814, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.5391, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.1837, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.3886, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(17.1449, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.6533, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 13%|█▎        | 131/976 [00:36<03:57,  3.56it/s]

torch.Size([])
torch.Size([128])
tensor(20.7746, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.0715, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.1485, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.8848, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(16.2617, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.9313, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(16.7659, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.1962, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.1446, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 14%|█▎        | 132/976 [00:36<03:51,  3.64it/s]

torch.Size([])
torch.Size([128])
tensor(34.0390, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.6965, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.9037, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.1041, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.0320, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.3172, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.0727, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.5905, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.4041, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 14%|█▎        | 133/976 [00:36<03:53,  3.61it/s]

torch.Size([])
torch.Size([128])
tensor(17.0090, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.8547, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(14.2121, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.9607, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.8531, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(17.6500, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.5795, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(17.1409, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.0296, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 14%|█▎        | 134/976 [00:37<03:56,  3.55it/s]

torch.Size([])
torch.Size([128])
tensor(17.8260, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.1168, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.2340, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(14.1195, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(15.7407, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.4450, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(15.6032, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.4943, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(14.4097, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 14%|█▍        | 135/976 [00:37<04:02,  3.47it/s]

torch.Size([])
torch.Size([128])
tensor(19.0301, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.1054, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(17.8947, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.5537, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.4661, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(16.5232, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.9449, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(16.9049, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.4514, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 14%|█▍        | 136/976 [00:37<04:04,  3.43it/s]

torch.Size([])
torch.Size([128])
tensor(18.8515, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.9919, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.8961, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(15.2407, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.0927, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(16.6846, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.1937, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.6214, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(17.0572, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 14%|█▍        | 137/976 [00:38<04:05,  3.42it/s]

torch.Size([])
torch.Size([128])
tensor(20.5004, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(16.0633, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(14.4464, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.4121, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.8725, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(17.4179, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(17.1491, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(17.7421, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.5575, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 14%|█▍        | 138/976 [00:38<04:04,  3.42it/s]

torch.Size([])
torch.Size([128])
tensor(12.2872, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.3088, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.4977, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.5714, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.7322, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.4746, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.7362, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.2038, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.6999, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 14%|█▍        | 139/976 [00:38<03:54,  3.56it/s]

torch.Size([])
torch.Size([128])
tensor(15.1112, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(16.2126, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(15.5055, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(14.7529, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(15.3549, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(14.7866, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(15.1826, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(14.6689, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(14.6363, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 14%|█▍        | 140/976 [00:38<03:53,  3.58it/s]

torch.Size([])
torch.Size([128])
tensor(17.5608, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(14.8392, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.3548, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.8143, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.9393, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(17.3359, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(14.4685, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.8796, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.0391, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 14%|█▍        | 141/976 [00:39<03:39,  3.80it/s]

torch.Size([])
torch.Size([128])
tensor(14.9442, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(14.8401, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(14.4534, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(14.4854, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(14.4758, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.7991, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(14.2702, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(14.9616, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.7872, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 15%|█▍        | 142/976 [00:39<03:40,  3.79it/s]

torch.Size([])
torch.Size([128])
tensor(27.6187, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.5983, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.8238, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(16.8712, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.7084, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.0843, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(16.5733, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.8225, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.8663, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 15%|█▍        | 143/976 [00:39<03:47,  3.66it/s]

torch.Size([])
torch.Size([128])
tensor(14.9548, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(14.3998, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.9785, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(14.3710, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(14.2926, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.8894, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(14.3249, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(14.2618, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.8983, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 15%|█▍        | 144/976 [00:39<03:59,  3.47it/s]

torch.Size([])
torch.Size([128])
tensor(44.1564, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.3849, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.9824, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.0141, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.4616, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.6125, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.0700, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(40.2067, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.7065, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 15%|█▍        | 145/976 [00:40<04:06,  3.37it/s]

torch.Size([])
torch.Size([128])
tensor(22.2095, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.6314, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.6078, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.4539, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(16.5247, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(17.3178, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.8884, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.4995, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.3002, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 15%|█▍        | 146/976 [00:40<03:54,  3.54it/s]

torch.Size([])
torch.Size([128])
tensor(14.0619, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.5499, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.6619, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.1422, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.0429, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.5234, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.8371, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(17.5076, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(16.8880, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 15%|█▌        | 147/976 [00:40<03:55,  3.52it/s]

torch.Size([])
torch.Size([128])
tensor(46.8807, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(43.2355, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.5902, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.1283, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(44.9529, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.5042, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(38.0661, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(42.8799, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(42.7508, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 15%|█▌        | 148/976 [00:41<03:54,  3.54it/s]

torch.Size([])
torch.Size([128])
tensor(14.1846, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.8158, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.3055, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.8632, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.4427, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.6131, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(14.0469, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.5572, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.2566, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 15%|█▌        | 149/976 [00:41<03:53,  3.55it/s]

torch.Size([])
torch.Size([128])
tensor(13.4622, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.0425, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(14.3301, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.9628, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.4822, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.9195, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.1015, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.4040, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.1049, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 15%|█▌        | 150/976 [00:41<03:39,  3.76it/s]

torch.Size([])
torch.Size([128])
tensor(17.1222, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(16.8697, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.1337, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(16.9604, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(17.0809, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.4493, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(17.9357, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.8827, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.8031, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 15%|█▌        | 151/976 [00:41<03:33,  3.87it/s]

torch.Size([])
torch.Size([128])
tensor(13.4222, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.0079, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.9063, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.2681, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.1908, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.8211, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.3235, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.2926, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.6247, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 16%|█▌        | 152/976 [00:42<03:27,  3.96it/s]

torch.Size([])
torch.Size([128])
tensor(26.2541, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.5643, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(42.7658, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.1818, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.9602, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.6476, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.5996, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.0832, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.5344, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 16%|█▌        | 153/976 [00:42<03:27,  3.96it/s]

torch.Size([])
torch.Size([128])
tensor(43.0494, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.9326, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(14.4563, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.7311, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.2643, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.2115, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.4227, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.8515, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.8534, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 16%|█▌        | 154/976 [00:42<03:36,  3.79it/s]

torch.Size([])
torch.Size([128])
tensor(29.6908, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.3381, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.3093, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.0691, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.8076, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.0324, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.1657, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.8898, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.4433, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 16%|█▌        | 155/976 [00:42<03:44,  3.66it/s]

torch.Size([])
torch.Size([128])
tensor(12.7525, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.5423, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.7233, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.1527, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.7918, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.9549, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.5592, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.9598, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.4807, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 16%|█▌        | 156/976 [00:43<03:46,  3.61it/s]

torch.Size([])
torch.Size([128])
tensor(12.5528, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.5236, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.5670, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.9168, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.3599, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.6554, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.9847, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.5718, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.1820, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 16%|█▌        | 157/976 [00:43<03:45,  3.62it/s]

torch.Size([])
torch.Size([128])
tensor(11.2671, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.5438, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.3092, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.1034, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.5156, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.1067, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.6892, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.0910, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.5046, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 16%|█▌        | 158/976 [00:43<03:47,  3.60it/s]

torch.Size([])
torch.Size([128])
tensor(18.4275, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.4255, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(16.7372, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.8065, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.7754, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.5645, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.8340, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(17.8966, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.1079, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 16%|█▋        | 159/976 [00:44<03:49,  3.56it/s]

torch.Size([])
torch.Size([128])
tensor(12.2273, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.5987, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.0753, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.0528, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.2292, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.0987, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.7027, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.2220, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.4940, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 16%|█▋        | 160/976 [00:44<03:59,  3.41it/s]

torch.Size([])
torch.Size([128])
tensor(11.9654, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.4436, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.6881, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.7599, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.6448, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.7891, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.9292, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.6038, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.7716, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 16%|█▋        | 161/976 [00:44<04:04,  3.33it/s]

torch.Size([])
torch.Size([128])
tensor(19.8274, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.2712, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.9619, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(15.4128, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(14.2781, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.8611, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(17.3958, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.3075, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(14.9924, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 17%|█▋        | 162/976 [00:44<04:03,  3.34it/s]

torch.Size([])
torch.Size([128])
tensor(11.7919, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.5479, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.4364, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.8167, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.7557, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.0051, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.6256, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.2486, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.1355, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 17%|█▋        | 163/976 [00:45<03:51,  3.51it/s]

torch.Size([])
torch.Size([128])
tensor(22.1334, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.1147, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.8793, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.0799, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(15.4315, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.2775, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(15.6823, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(16.1387, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.8386, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 17%|█▋        | 164/976 [00:45<03:57,  3.41it/s]

torch.Size([])
torch.Size([128])
tensor(10.8452, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.0943, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.9796, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.3313, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.2377, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.2604, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.7105, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.0719, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.5360, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 17%|█▋        | 165/976 [00:45<03:49,  3.53it/s]

torch.Size([])
torch.Size([128])
tensor(15.6203, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.6239, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.2038, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.7287, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.8919, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.1394, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.2015, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.2241, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.7095, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 17%|█▋        | 166/976 [00:46<03:52,  3.49it/s]

torch.Size([])
torch.Size([128])
tensor(22.8867, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.9669, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(14.5125, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.9793, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.4229, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.8206, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.3021, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.5496, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.2833, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 17%|█▋        | 167/976 [00:46<03:48,  3.54it/s]

torch.Size([])
torch.Size([128])
tensor(11.2122, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.2343, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.6065, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.8470, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.9989, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.9733, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.4004, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.8948, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.7541, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 17%|█▋        | 168/976 [00:46<03:55,  3.43it/s]

torch.Size([])
torch.Size([128])
tensor(18.3757, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.6444, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(16.1906, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.1884, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(15.2616, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.5721, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.1134, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.8807, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(14.0428, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 17%|█▋        | 169/976 [00:46<03:56,  3.41it/s]

torch.Size([])
torch.Size([128])
tensor(11.0068, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.2499, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.6895, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.8967, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.8170, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.7307, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.7461, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.6774, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.6980, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 17%|█▋        | 170/976 [00:47<03:47,  3.54it/s]

torch.Size([])
torch.Size([128])
tensor(11.1741, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.1233, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.8285, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.1877, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.6759, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.8660, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.3103, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.5940, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.8796, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
to

 18%|█▊        | 171/976 [00:47<03:46,  3.55it/s]

torch.Size([])
torch.Size([128])
tensor(10.4366, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.8782, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.6786, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.6379, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.7921, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.2285, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.3264, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.4235, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.5990, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
to

 18%|█▊        | 172/976 [00:47<03:43,  3.60it/s]

torch.Size([])
torch.Size([128])
tensor(10.9561, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.1642, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.2674, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.8507, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.9853, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.1137, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.2786, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.0069, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.9936, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
to

 18%|█▊        | 173/976 [00:48<03:34,  3.75it/s]

torch.Size([])
torch.Size([128])
tensor(28.1996, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.4015, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.6232, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.8242, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.3803, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(15.9312, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(16.8391, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.2225, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.2964, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 18%|█▊        | 174/976 [00:48<03:38,  3.67it/s]

torch.Size([])
torch.Size([128])
tensor(9.9593, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.1721, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.5958, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.3595, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.8381, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.2349, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.8749, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.2960, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.9035, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch

 18%|█▊        | 175/976 [00:48<03:31,  3.79it/s]

torch.Size([])
torch.Size([128])
tensor(10.3568, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.8400, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.0935, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.2898, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.1239, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.0193, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.9350, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.6666, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.5299, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch

 18%|█▊        | 176/976 [00:48<03:31,  3.79it/s]

torch.Size([])
torch.Size([128])
tensor(10.1082, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.4748, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.7688, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.6686, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.1341, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.8235, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.3871, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.8448, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.5330, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.S

 18%|█▊        | 177/976 [00:49<03:32,  3.76it/s]

torch.Size([])
torch.Size([128])
tensor(10.3227, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.8328, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.7063, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.6659, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.0159, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.8224, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.0702, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.7930, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.0309, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 18%|█▊        | 178/976 [00:49<03:25,  3.88it/s]

torch.Size([])
torch.Size([128])
tensor(9.9728, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.7007, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.5346, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.8264, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.7812, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.6128, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.4272, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.3963, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.1351, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 18%|█▊        | 179/976 [00:49<03:35,  3.70it/s]

torch.Size([])
torch.Size([128])
tensor(9.7771, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.5019, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.4474, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.7699, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.7201, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.1735, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.5323, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.2603, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.6253, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 18%|█▊        | 180/976 [00:50<04:26,  2.99it/s]

torch.Size([])
torch.Size([128])
tensor(9.8066, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.2648, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.5831, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.3300, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.3514, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.0246, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.3763, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.4246, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.9988, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 19%|█▊        | 181/976 [00:50<05:05,  2.60it/s]

torch.Size([])
torch.Size([128])
tensor(9.7068, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.4087, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.3935, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.0135, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.2924, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.3028, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.9609, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.1617, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.2610, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 19%|█▊        | 182/976 [00:51<05:34,  2.38it/s]

torch.Size([])
torch.Size([128])
tensor(9.4719, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.1386, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.4545, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.9722, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.2137, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.3312, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.7254, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.9741, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.1757, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 19%|█▉        | 183/976 [00:51<05:52,  2.25it/s]

torch.Size([])
torch.Size([128])
tensor(9.4226, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.0126, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.0604, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.9951, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.0573, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.1811, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.7429, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.7239, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.4244, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 19%|█▉        | 184/976 [00:52<06:55,  1.91it/s]

torch.Size([])
torch.Size([128])
tensor(8.5035, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.0923, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.3054, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.0861, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.3083, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.8186, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.0221, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.0574, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.7973, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 19%|█▉        | 185/976 [00:52<06:02,  2.18it/s]

torch.Size([])
torch.Size([128])
tensor(8.7252, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.6443, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.8254, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.3403, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.6891, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.8066, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.5130, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.7564, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.9111, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 19%|█▉        | 186/976 [00:52<05:25,  2.43it/s]

torch.Size([])
torch.Size([128])
tensor(8.6416, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.6651, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.7678, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.9402, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.2518, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.9648, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.5450, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.4870, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.5189, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 19%|█▉        | 187/976 [00:53<04:54,  2.68it/s]

torch.Size([])
torch.Size([128])
tensor(8.3867, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.8711, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.4229, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.8510, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.4450, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.5544, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.1074, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.6645, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.3259, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 19%|█▉        | 188/976 [00:53<04:38,  2.82it/s]

torch.Size([])
torch.Size([128])
tensor(8.4318, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.6929, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.7113, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.2115, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.2864, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.1395, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.7956, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.0704, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.3988, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 19%|█▉        | 189/976 [00:53<04:23,  2.98it/s]

torch.Size([])
torch.Size([128])
tensor(8.5426, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.9380, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.9605, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.1827, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.5535, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.4584, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.8199, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.0464, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.0691, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 19%|█▉        | 190/976 [00:54<04:05,  3.20it/s]

torch.Size([])
torch.Size([128])
tensor(8.6904, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.4593, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.8715, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.1118, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.5506, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.2742, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.0112, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.5559, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.0665, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 20%|█▉        | 191/976 [00:54<03:56,  3.32it/s]

torch.Size([])
torch.Size([128])
tensor(7.8314, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.3416, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.2871, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.2110, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.7269, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.1152, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.1195, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.9731, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.8236, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 20%|█▉        | 192/976 [00:54<03:52,  3.37it/s]

torch.Size([])
torch.Size([128])
tensor(11.6282, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(15.5857, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.3747, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.5629, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(17.0873, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.7945, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(17.3357, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.4928, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(16.4999, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 20%|█▉        | 193/976 [00:54<03:56,  3.31it/s]

torch.Size([])
torch.Size([128])
tensor(8.3341, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.0853, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.7587, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.5792, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.1418, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.8433, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.5885, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.4580, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.9939, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 20%|█▉        | 194/976 [00:55<03:56,  3.30it/s]

torch.Size([])
torch.Size([128])
tensor(7.9315, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.7918, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.6487, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.9145, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.8577, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.0674, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.4445, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.2027, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.5823, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 20%|█▉        | 195/976 [00:55<03:50,  3.40it/s]

torch.Size([])
torch.Size([128])
tensor(20.8769, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.6910, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.6810, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.4547, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(16.4668, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(15.8316, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(15.8467, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.1473, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.0048, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 20%|██        | 196/976 [00:55<03:48,  3.42it/s]

torch.Size([])
torch.Size([128])
tensor(19.9318, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(16.4858, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.0756, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.7845, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.7045, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.7437, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.1946, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.2029, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.7947, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 20%|██        | 197/976 [00:56<03:48,  3.42it/s]

torch.Size([])
torch.Size([128])
tensor(7.3017, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.7541, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.4125, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.5367, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.5026, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.1060, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.3857, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.3070, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.2409, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 20%|██        | 198/976 [00:56<03:40,  3.53it/s]

torch.Size([])
torch.Size([128])
tensor(26.9896, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.8294, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(17.9602, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(39.6882, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.6497, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.6856, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.9836, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.6526, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(39.9309, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 20%|██        | 199/976 [00:56<03:34,  3.62it/s]

torch.Size([])
torch.Size([128])
tensor(35.8943, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(41.8838, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(16.3977, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(46.0678, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.9615, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.7328, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(40.2660, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.1925, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.5152, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 20%|██        | 200/976 [00:56<03:24,  3.79it/s]

torch.Size([])
torch.Size([128])
tensor(15.1714, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.6793, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.4322, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.1880, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.7499, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.6392, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(16.0059, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.7182, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.7026, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 21%|██        | 201/976 [00:57<03:16,  3.95it/s]

torch.Size([])
torch.Size([128])
tensor(30.3490, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.7621, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(38.9810, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(39.5273, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(49.2595, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.7972, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(40.2527, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(17.1804, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.2850, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 21%|██        | 202/976 [00:57<03:12,  4.02it/s]

torch.Size([])
torch.Size([128])
tensor(14.7739, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.9006, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.3972, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.9276, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.1498, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.8757, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.6364, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(14.8930, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.0813, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 21%|██        | 203/976 [00:57<03:12,  4.02it/s]

torch.Size([])
torch.Size([128])
tensor(50.3241, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(69.3495, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(74.2466, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(60.3685, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(60.7457, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(55.1258, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(75.8844, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(62.6716, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(57.5150, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 21%|██        | 204/976 [00:57<03:10,  4.05it/s]

torch.Size([])
torch.Size([128])
tensor(7.1911, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.0311, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.8969, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.1984, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.8681, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.2318, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.1703, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.1170, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.8232, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 21%|██        | 205/976 [00:58<03:19,  3.87it/s]

torch.Size([])
torch.Size([128])
tensor(73.5680, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(55.4628, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(63.9296, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(61.9834, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(63.9727, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(70.0984, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(41.7589, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(79.1493, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(68.4269, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 21%|██        | 206/976 [00:58<03:25,  3.74it/s]

torch.Size([])
torch.Size([128])
tensor(35.2490, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(41.9196, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.7769, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.0009, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.6697, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.9318, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(42.2163, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.1373, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.5146, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 21%|██        | 207/976 [00:58<03:33,  3.60it/s]

torch.Size([])
torch.Size([128])
tensor(27.8738, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.6387, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.3618, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(54.8885, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.2386, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(40.6045, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(38.6096, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.2304, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.3291, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 21%|██▏       | 208/976 [00:58<03:30,  3.65it/s]

torch.Size([])
torch.Size([128])
tensor(24.0933, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.3859, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.3500, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(15.7712, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.5982, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.1230, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.9685, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.6204, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.8875, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 21%|██▏       | 209/976 [00:59<03:31,  3.62it/s]

torch.Size([])
torch.Size([128])
tensor(16.4525, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.4323, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.7666, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.6371, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.7681, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(14.6998, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(17.0180, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.4288, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.4124, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
to

 22%|██▏       | 210/976 [00:59<03:36,  3.55it/s]

torch.Size([])
torch.Size([128])
tensor(18.0500, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(16.1240, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.2890, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.9015, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(15.1833, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.3729, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.4533, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.0023, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.4951, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 22%|██▏       | 211/976 [00:59<03:38,  3.50it/s]

torch.Size([])
torch.Size([128])
tensor(6.7278, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.5925, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.4724, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.6306, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.6249, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.4842, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.2318, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.4254, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.4642, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 22%|██▏       | 212/976 [01:00<03:40,  3.46it/s]

torch.Size([])
torch.Size([128])
tensor(30.9964, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.4235, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.6412, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.8695, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.4281, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.6863, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(17.8071, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.6674, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.2353, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
to

 22%|██▏       | 213/976 [01:00<03:46,  3.37it/s]

torch.Size([])
torch.Size([128])
tensor(27.0379, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.1609, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.2038, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.8830, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.4869, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.3404, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.3892, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.7022, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.6667, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
to

 22%|██▏       | 214/976 [01:00<03:41,  3.44it/s]

torch.Size([])
torch.Size([128])
tensor(6.5565, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.1784, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.1927, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.2744, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.2836, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.9952, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.2431, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.0389, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.0322, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 22%|██▏       | 215/976 [01:01<03:39,  3.46it/s]

torch.Size([])
torch.Size([128])
tensor(23.9365, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.1925, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.7158, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.0167, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.5911, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(17.0702, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.6934, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(14.0643, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.2980, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 22%|██▏       | 216/976 [01:01<03:44,  3.39it/s]

torch.Size([])
torch.Size([128])
tensor(24.7148, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.6434, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.0692, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.4946, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.3223, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.5415, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.0972, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(17.6394, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(14.3843, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 22%|██▏       | 217/976 [01:01<03:38,  3.48it/s]

torch.Size([])
torch.Size([128])
tensor(5.8885, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.1471, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.0254, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.9153, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.0449, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.7136, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.7632, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.8347, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.8696, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 22%|██▏       | 218/976 [01:01<03:29,  3.62it/s]

torch.Size([])
torch.Size([128])
tensor(5.8850, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.0716, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.6415, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.0447, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.6800, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.5183, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.9369, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.8868, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.4828, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 22%|██▏       | 219/976 [01:02<03:34,  3.54it/s]

torch.Size([])
torch.Size([128])
tensor(6.0050, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.9255, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.5490, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.8726, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.5928, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.8551, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.5458, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.7303, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.5820, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 23%|██▎       | 220/976 [01:02<03:36,  3.49it/s]

torch.Size([])
torch.Size([128])
tensor(5.6006, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.7607, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.4711, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.0121, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.6179, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.6371, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.3539, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.6056, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.3377, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 23%|██▎       | 221/976 [01:02<03:36,  3.49it/s]

torch.Size([])
torch.Size([128])
tensor(5.6232, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.8441, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.6023, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.5337, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.5239, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.6740, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.5390, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.2594, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.3280, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 23%|██▎       | 222/976 [01:03<03:39,  3.44it/s]

torch.Size([])
torch.Size([128])
tensor(20.9277, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.0962, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.0853, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(17.4837, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.4772, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(14.8454, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.2808, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.7250, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.8030, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 23%|██▎       | 223/976 [01:03<03:40,  3.42it/s]

torch.Size([])
torch.Size([128])
tensor(5.6055, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.5623, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.4063, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.2264, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.4213, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.2359, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.2238, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.3213, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.0848, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 23%|██▎       | 224/976 [01:03<03:40,  3.41it/s]

torch.Size([])
torch.Size([128])
tensor(5.1965, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.3292, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.2377, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.9450, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.1541, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.6966, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.0369, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.2375, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.8501, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 23%|██▎       | 225/976 [01:03<03:40,  3.40it/s]

torch.Size([])
torch.Size([128])
tensor(5.5690, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.0781, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.5302, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.4111, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.4781, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.3469, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.8867, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.2585, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.3267, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 23%|██▎       | 226/976 [01:04<03:42,  3.37it/s]

torch.Size([])
torch.Size([128])
tensor(5.1990, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.2643, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.0903, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.9536, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.2208, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.8567, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.7760, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.8949, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.0879, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 23%|██▎       | 227/976 [01:04<03:40,  3.40it/s]

torch.Size([])
torch.Size([128])
tensor(10.4188, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.2112, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(16.4788, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(16.6329, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.4430, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.0735, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(14.6124, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.4910, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.0921, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torc

 23%|██▎       | 228/976 [01:04<03:31,  3.53it/s]

torch.Size([])
torch.Size([128])
tensor(5.0070, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.8247, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.0402, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.3827, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.0720, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.9613, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.0621, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.0206, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.0814, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 23%|██▎       | 229/976 [01:05<03:29,  3.57it/s]

torch.Size([])
torch.Size([128])
tensor(5.0668, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.2088, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.9541, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.8596, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.8570, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.8800, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.8736, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.9486, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.6748, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 24%|██▎       | 230/976 [01:05<03:27,  3.59it/s]

torch.Size([])
torch.Size([128])
tensor(41.4546, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.3952, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.3153, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.3610, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.4153, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.1085, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.9508, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(15.8023, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(43.5386, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
to

 24%|██▎       | 231/976 [01:05<03:32,  3.50it/s]

torch.Size([])
torch.Size([128])
tensor(5.6237, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.3274, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.0210, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.8022, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.8431, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.9905, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.6602, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.3138, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.4661, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 24%|██▍       | 232/976 [01:05<03:25,  3.63it/s]

torch.Size([])
torch.Size([128])
tensor(9.5797, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.9027, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.8977, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.6836, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.9515, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.1481, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.2020, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.3279, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.7016, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Si

 24%|██▍       | 233/976 [01:06<03:21,  3.69it/s]

torch.Size([])
torch.Size([128])
tensor(40.8342, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(57.6199, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.0382, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.5719, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(44.9313, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.8633, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(41.8246, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(52.4666, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(58.9813, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 24%|██▍       | 234/976 [01:06<03:25,  3.62it/s]

torch.Size([])
torch.Size([128])
tensor(5.0102, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.7645, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.7806, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.3945, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.7327, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.7118, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.5842, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.8064, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.5542, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 24%|██▍       | 235/976 [01:06<03:32,  3.49it/s]

torch.Size([])
torch.Size([128])
tensor(4.6500, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.6156, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.6398, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.7111, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.6058, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.7045, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.3340, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.5002, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.4817, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 24%|██▍       | 236/976 [01:07<03:33,  3.47it/s]

torch.Size([])
torch.Size([128])
tensor(4.6683, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.6932, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.6504, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.2477, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.5977, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.3008, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.4725, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.3529, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.4507, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 24%|██▍       | 237/976 [01:07<03:33,  3.46it/s]

torch.Size([])
torch.Size([128])
tensor(14.6675, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.1097, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.8912, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.2307, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.9316, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.0008, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.7412, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.9115, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.2271, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 24%|██▍       | 238/976 [01:07<03:31,  3.49it/s]

torch.Size([])
torch.Size([128])
tensor(42.8496, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.4835, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.7644, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.1316, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.0137, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.2554, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(42.8771, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(14.9246, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.7860, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
tor

 24%|██▍       | 239/976 [01:07<03:26,  3.57it/s]

torch.Size([])
torch.Size([128])
tensor(8.7796, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.0056, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.4153, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.2508, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.4512, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.7123, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(39.3927, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.6750, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.2058, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
tor

 25%|██▍       | 240/976 [01:08<03:13,  3.80it/s]

torch.Size([])
torch.Size([128])
tensor(45.8741, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(65.8835, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.0730, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.2075, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(42.7010, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(43.7660, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.1135, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(49.9142, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(45.8697, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 25%|██▍       | 241/976 [01:08<03:16,  3.74it/s]

torch.Size([])
torch.Size([128])
tensor(29.1280, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.3014, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.4666, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(44.4498, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.6804, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.1319, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.0241, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(50.7353, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.4321, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 25%|██▍       | 242/976 [01:08<03:18,  3.69it/s]

torch.Size([])
torch.Size([128])
tensor(4.0642, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.2029, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.1810, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.4823, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.0769, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.2664, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.2307, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.1309, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.9974, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 25%|██▍       | 243/976 [01:08<03:25,  3.56it/s]

torch.Size([])
torch.Size([128])
tensor(72.0531, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(68.3323, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(74.9245, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(104.4250, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(93.2573, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(73.0654, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(76.6207, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(76.9081, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(83.3020, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)


 25%|██▌       | 244/976 [01:09<03:30,  3.47it/s]

torch.Size([])
torch.Size([128])
tensor(16.0734, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.7152, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.2960, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.1060, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.2112, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(17.8990, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.8109, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.3004, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.0703, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 25%|██▌       | 245/976 [01:09<03:29,  3.48it/s]

torch.Size([])
torch.Size([128])
tensor(34.7491, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.2896, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(16.5179, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.6394, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(14.6049, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.2989, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.3482, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.7991, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.2563, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 25%|██▌       | 246/976 [01:09<03:31,  3.45it/s]

torch.Size([])
torch.Size([128])
tensor(55.5922, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(82.8675, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(52.8214, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(66.4600, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(54.4932, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(59.8007, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(70.9216, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(70.6638, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(67.1642, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 25%|██▌       | 247/976 [01:10<03:31,  3.44it/s]

torch.Size([])
torch.Size([128])
tensor(15.9376, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.7598, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.5426, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.4606, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.0413, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(16.2462, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.9114, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.4883, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.3589, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 25%|██▌       | 248/976 [01:10<03:26,  3.53it/s]

torch.Size([])
torch.Size([128])
tensor(34.1387, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(81.9646, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(62.7962, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(56.3885, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(78.2424, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(43.8659, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(67.3362, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(43.7041, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(46.2310, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 26%|██▌       | 249/976 [01:10<03:14,  3.75it/s]

torch.Size([])
torch.Size([128])
tensor(8.9263, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.8806, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.7576, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.2085, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.0909, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.6257, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(16.2167, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.8289, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.7129, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
to

 26%|██▌       | 250/976 [01:10<03:09,  3.83it/s]

torch.Size([])
torch.Size([128])
tensor(45.3861, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.1458, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(53.8551, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(41.1262, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(43.2196, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.1676, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.9709, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(64.4540, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(45.1236, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 26%|██▌       | 251/976 [01:11<03:05,  3.91it/s]

torch.Size([])
torch.Size([128])
tensor(6.1713, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.7497, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.9600, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.6846, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.9659, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.2540, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.6697, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.6417, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(15.8701, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
tor

 26%|██▌       | 252/976 [01:11<03:01,  4.00it/s]

torch.Size([])
torch.Size([128])
tensor(67.5441, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(48.6470, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(70.3264, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(46.9394, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(60.3447, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(56.5704, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(45.6019, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(70.3387, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(69.8666, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 26%|██▌       | 253/976 [01:11<02:57,  4.07it/s]

torch.Size([])
torch.Size([128])
tensor(36.4731, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(59.5765, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.5258, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(42.3275, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(39.6346, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.4071, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(66.5812, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.1705, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(48.0508, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 26%|██▌       | 254/976 [01:11<03:00,  3.99it/s]

torch.Size([])
torch.Size([128])
tensor(4.0044, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.0699, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.0565, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.1698, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.1220, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.9747, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.1342, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.1463, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.0866, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 26%|██▌       | 255/976 [01:12<02:57,  4.06it/s]

torch.Size([])
torch.Size([128])
tensor(53.8697, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(68.0912, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(41.4540, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.0552, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(45.8587, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(42.8005, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(40.0216, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(67.6819, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(72.8268, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 26%|██▌       | 256/976 [01:12<02:57,  4.06it/s]

torch.Size([])
torch.Size([128])
tensor(3.8620, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.0898, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.7835, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.2386, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.9947, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.8753, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.0354, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.0361, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.0555, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 26%|██▋       | 257/976 [01:12<03:03,  3.92it/s]

torch.Size([])
torch.Size([128])
tensor(22.7894, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.0773, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.1648, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.7595, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.3423, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.9904, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.3079, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.0078, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.4097, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 26%|██▋       | 258/976 [01:12<03:08,  3.81it/s]

torch.Size([])
torch.Size([128])
tensor(32.4397, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(72.7443, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.6868, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(43.0065, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(72.6641, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(15.7861, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(54.5252, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.6537, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(45.9729, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 27%|██▋       | 259/976 [01:13<03:11,  3.74it/s]

torch.Size([])
torch.Size([128])
tensor(4.0192, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.9151, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.0379, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.9969, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.8298, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.0318, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.0825, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.9534, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.7819, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 27%|██▋       | 260/976 [01:13<03:18,  3.61it/s]

torch.Size([])
torch.Size([128])
tensor(31.8339, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(53.0332, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(54.7693, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.8039, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(52.1819, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(52.8342, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.0175, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(42.4493, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(41.0748, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 27%|██▋       | 261/976 [01:13<03:20,  3.57it/s]

torch.Size([])
torch.Size([128])
tensor(27.9071, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.9358, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.5234, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.2224, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(15.8230, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(43.6247, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(15.9770, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.1750, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.7188, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 27%|██▋       | 262/976 [01:14<03:24,  3.50it/s]

torch.Size([])
torch.Size([128])
tensor(46.0193, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.3593, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(63.2912, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.0200, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.2583, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(43.0680, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(53.4003, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(42.9762, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(41.1446, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 27%|██▋       | 263/976 [01:14<03:27,  3.44it/s]

torch.Size([])
torch.Size([128])
tensor(9.9258, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.3945, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.8191, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.7581, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.6266, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.7132, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.7871, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.7222, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.3546, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
to

 27%|██▋       | 264/976 [01:14<03:24,  3.49it/s]

torch.Size([])
torch.Size([128])
tensor(42.3323, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(44.9737, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(40.7133, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(45.2577, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(41.6043, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(43.0063, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.8969, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(53.8058, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(54.3933, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 27%|██▋       | 265/976 [01:14<03:27,  3.43it/s]

torch.Size([])
torch.Size([128])
tensor(52.0775, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.8764, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.4783, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(61.0342, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(73.4121, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.8137, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.3779, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.8069, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(57.1250, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 27%|██▋       | 266/976 [01:15<03:30,  3.38it/s]

torch.Size([])
torch.Size([128])
tensor(3.7786, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.7917, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.9107, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.7719, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.8353, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.7283, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.0162, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.7875, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.8127, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 27%|██▋       | 267/976 [01:15<03:25,  3.46it/s]

torch.Size([])
torch.Size([128])
tensor(46.5766, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(50.0036, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(41.1631, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.1397, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(37.7260, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.1590, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(43.3719, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(68.5223, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.5284, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 27%|██▋       | 268/976 [01:15<03:20,  3.54it/s]

 28%|██▊       | 269/976 [01:16<03:18,  3.56it/s]

 28%|██▊       | 270/976 [01:16<03:18,  3.57it/s]

 28%|██▊       | 271/976 [01:16<03:21,  3.49it/s]

 28%|██▊       | 272/976 [01:16<03:14,  3.62it/s]

 28%|██▊       | 273/976 [01:17<03:10,  3.69it/s]

 28%|██▊       | 274/976 [01:17<03:12,  3.64it/s]

 28%|██▊       | 275/976 [01:17<03:19,  3.51it/s]

 28%|██▊       | 276/976 [01:17<03:16,  3.56it/s]

 28%|██▊       | 277/976 [01:18<03:12,  3.63it/s]

 28%|██▊       | 278/976 [01:18<03:21,  3.46it/s]

 29%|██▊       | 279/976 [01:18<03:12,  3.61it/s]

 29%|██▊       | 280/976 [01:19<03:15,  3.56it/s]

 29%|██▉       | 281/976 [01:19<03:22,  3.44it/s]

 29%|██▉       | 282/976 [01:19<03:25,  3.37it/s]

 29%|██▉       | 283/976 [01:20<03:18,  3.49it/s]

 29%|██▉       | 284/976 [01:20<03:10,  3.64it/s]

 29%|██▉       | 285/976 [01:20<03:09,  3.64it/s]

 29%|██▉       | 286/976 [01:20<03:13,  3.57it/s]

 29%|██▉       | 287/976 [01:21<03:17,  3.49it/s]

 30%|██▉       | 288/976 [01:21<03:21,  3.41it/s]

 30%|██▉       | 289/976 [01:21<03:19,  3.44it/s]

 30%|██▉       | 290/976 [01:21<03:13,  3.55it/s]

 30%|██▉       | 291/976 [01:22<03:10,  3.59it/s]

 30%|██▉       | 292/976 [01:22<03:08,  3.63it/s]

 30%|███       | 293/976 [01:22<03:06,  3.66it/s]

 30%|███       | 294/976 [01:23<03:03,  3.72it/s]

 30%|███       | 295/976 [01:23<03:05,  3.67it/s]

 30%|███       | 296/976 [01:23<03:10,  3.58it/s]

 30%|███       | 297/976 [01:23<03:12,  3.53it/s]

 31%|███       | 298/976 [01:24<03:07,  3.61it/s]

 31%|███       | 299/976 [01:24<02:58,  3.80it/s]

 31%|███       | 300/976 [01:24<02:53,  3.89it/s]

torch.Size([])
torch.Size([128])
tensor(2.8341, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.9521, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.8521, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.8691, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.8006, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.7190, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.8387, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.7407, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.7218, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 31%|███       | 301/976 [01:24<02:50,  3.96it/s]

torch.Size([])
torch.Size([128])
tensor(64.8185, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(43.6539, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(38.4355, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(37.4058, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.8913, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(66.3039, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(39.1231, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(44.0478, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.3354, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 31%|███       | 302/976 [01:25<02:49,  3.98it/s]

torch.Size([])
torch.Size([128])
tensor(2.8327, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.8541, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.8137, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.8651, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.9551, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.8049, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.7602, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.7970, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.9149, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 31%|███       | 303/976 [01:25<02:54,  3.85it/s]

torch.Size([])
torch.Size([128])
tensor(27.3455, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.5873, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.9524, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.7027, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.4066, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.6260, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.8256, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.5542, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.0176, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 31%|███       | 304/976 [01:25<03:01,  3.71it/s]

torch.Size([])
torch.Size([128])
tensor(16.6951, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(17.6370, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.1369, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(56.9596, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.2061, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.8710, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.3421, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.9082, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.2954, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
to

 31%|███▏      | 305/976 [01:25<03:01,  3.71it/s]

torch.Size([])
torch.Size([128])
tensor(42.0219, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.7318, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.2068, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.9677, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.7867, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.0540, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.5289, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.5278, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(41.4464, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 31%|███▏      | 306/976 [01:26<02:51,  3.90it/s]

torch.Size([])
torch.Size([128])
tensor(2.7296, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.6524, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.6517, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.5264, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.7515, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4207, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.5478, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.5670, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4863, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 31%|███▏      | 307/976 [01:26<02:54,  3.83it/s]

torch.Size([])
torch.Size([128])
tensor(36.6181, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.0022, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.6355, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.1423, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.0217, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.6954, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.2816, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.2728, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.3659, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
to

 32%|███▏      | 308/976 [01:26<02:55,  3.80it/s]

torch.Size([])
torch.Size([128])
tensor(54.6740, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.5848, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.7297, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4463, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.0670, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.6629, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.8627, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(43.7792, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(41.2082, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
tor

 32%|███▏      | 309/976 [01:26<02:51,  3.89it/s]

torch.Size([])
torch.Size([128])
tensor(2.4576, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.5902, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4710, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3634, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4061, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3077, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4217, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3862, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3495, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 32%|███▏      | 310/976 [01:27<02:51,  3.88it/s]

torch.Size([])
torch.Size([128])
tensor(2.4618, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4594, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4576, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2576, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3144, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2269, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3346, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3657, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2021, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 32%|███▏      | 311/976 [01:27<02:56,  3.77it/s]

torch.Size([])
torch.Size([128])
tensor(2.4060, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3828, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3363, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2327, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2468, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3914, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1566, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1705, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2755, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 32%|███▏      | 312/976 [01:27<03:02,  3.64it/s]

torch.Size([])
torch.Size([128])
tensor(15.0464, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.2255, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.4031, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.8911, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.6042, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(16.6385, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.4581, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(17.6415, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0388, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
to

 32%|███▏      | 313/976 [01:28<03:07,  3.53it/s]

torch.Size([])
torch.Size([128])
tensor(2.3253, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2529, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2309, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1975, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2245, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1211, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1281, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1583, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1615, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 32%|███▏      | 314/976 [01:28<03:06,  3.55it/s]

torch.Size([])
torch.Size([128])
tensor(2.3470, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2145, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1174, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1820, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2010, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1075, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0734, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0641, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0219, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 32%|███▏      | 315/976 [01:28<03:02,  3.62it/s]

torch.Size([])
torch.Size([128])
tensor(2.1345, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1752, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0683, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0666, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1199, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0682, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9576, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9029, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9948, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 32%|███▏      | 316/976 [01:28<03:03,  3.61it/s]

torch.Size([])
torch.Size([128])
tensor(2.1045, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0777, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0844, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9753, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0091, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9618, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9776, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9193, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9937, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 32%|███▏      | 317/976 [01:29<02:56,  3.73it/s]

torch.Size([])
torch.Size([128])
tensor(2.1312, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9594, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0409, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8910, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0121, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9502, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8352, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8567, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8770, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 33%|███▎      | 318/976 [01:29<02:55,  3.74it/s]

torch.Size([])
torch.Size([128])
tensor(2.1266, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9707, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9698, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8654, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9518, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8983, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8783, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8207, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8245, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 33%|███▎      | 319/976 [01:29<02:55,  3.75it/s]

torch.Size([])
torch.Size([128])
tensor(19.9223, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.8415, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.4167, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.5639, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(40.7731, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(14.2558, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.4329, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.1868, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(14.3147, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 33%|███▎      | 320/976 [01:30<02:59,  3.66it/s]

torch.Size([])
torch.Size([128])
tensor(1.8901, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8958, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8672, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7570, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7818, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8721, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7370, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7808, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7556, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 33%|███▎      | 321/976 [01:30<03:02,  3.59it/s]

torch.Size([])
torch.Size([128])
tensor(1.8164, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7321, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8088, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8529, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7487, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7626, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7935, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5854, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6093, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 33%|███▎      | 322/976 [01:30<03:06,  3.50it/s]

torch.Size([])
torch.Size([128])
tensor(1.8049, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6924, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7615, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7177, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6350, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6951, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6360, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6757, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6040, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 33%|███▎      | 323/976 [01:30<03:02,  3.57it/s]

torch.Size([])
torch.Size([128])
tensor(1.7640, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6915, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6223, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7077, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6274, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5563, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6494, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6207, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5452, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 33%|███▎      | 324/976 [01:31<02:55,  3.71it/s]

torch.Size([])
torch.Size([128])
tensor(1.7418, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6204, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5788, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6234, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4691, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6886, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5411, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5384, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5028, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 33%|███▎      | 325/976 [01:31<03:14,  3.35it/s]

torch.Size([])
torch.Size([128])
tensor(1.6406, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5531, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5848, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5958, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6249, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4676, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5009, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4618, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4818, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 33%|███▎      | 326/976 [01:31<03:45,  2.89it/s]

torch.Size([])
torch.Size([128])
tensor(1.5939, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5912, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4724, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5079, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4720, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4700, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4364, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4707, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4089, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 34%|███▎      | 327/976 [01:32<04:07,  2.63it/s]

torch.Size([])
torch.Size([128])
tensor(1.5842, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5129, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4279, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4526, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4431, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4689, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3686, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3863, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2836, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 34%|███▎      | 328/976 [01:32<04:22,  2.47it/s]

torch.Size([])
torch.Size([128])
tensor(1.4378, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4485, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4550, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4282, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3726, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3621, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3518, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3788, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3307, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 34%|███▎      | 329/976 [01:33<05:32,  1.95it/s]

torch.Size([])
torch.Size([128])
tensor(1.4171, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3830, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4172, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3626, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3328, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3285, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3594, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2613, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3174, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 34%|███▍      | 330/976 [01:33<04:47,  2.25it/s]

torch.Size([])
torch.Size([128])
tensor(1.3912, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3775, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2893, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3352, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3190, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2623, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2811, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2381, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2306, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 34%|███▍      | 331/976 [01:34<04:13,  2.54it/s]

torch.Size([])
torch.Size([128])
tensor(1.3049, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3421, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2767, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2810, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2372, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2375, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2186, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2235, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1797, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 34%|███▍      | 332/976 [01:34<03:49,  2.81it/s]

torch.Size([])
torch.Size([128])
tensor(1.2656, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3126, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2417, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2057, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1546, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2403, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1416, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2046, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1006, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 34%|███▍      | 333/976 [01:34<03:32,  3.02it/s]

torch.Size([])
torch.Size([128])
tensor(1.2062, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2187, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2156, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2077, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1832, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1226, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1288, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1373, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1234, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 34%|███▍      | 334/976 [01:35<03:23,  3.16it/s]

torch.Size([])
torch.Size([128])
tensor(1.1935, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2162, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1034, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1617, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1106, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0876, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0443, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1591, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0811, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 34%|███▍      | 335/976 [01:35<03:17,  3.24it/s]

torch.Size([])
torch.Size([128])
tensor(1.1533, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1319, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1070, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1082, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0937, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0445, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0772, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0192, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0269, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 34%|███▍      | 336/976 [01:35<03:14,  3.28it/s]

torch.Size([])
torch.Size([128])
tensor(1.1150, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1229, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0213, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0791, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0274, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0520, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0312, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.9668, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.9692, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 35%|███▍      | 337/976 [01:35<03:10,  3.36it/s]

torch.Size([])
torch.Size([128])
tensor(1.0587, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0089, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0734, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0300, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0439, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0005, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.9361, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.9368, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.9423, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 35%|███▍      | 338/976 [01:36<02:57,  3.59it/s]

torch.Size([])
torch.Size([128])
tensor(1.0313, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0545, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0614, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0335, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.9798, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0050, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.9522, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.9515, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.9234, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 35%|███▍      | 339/976 [01:36<02:58,  3.58it/s]

torch.Size([])
torch.Size([128])
tensor(1.0079, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.9703, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.9337, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.9669, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.9609, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.8927, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.9066, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.8748, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.8695, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 35%|███▍      | 340/976 [01:36<02:54,  3.63it/s]

torch.Size([])
torch.Size([128])
tensor(0.9615, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.9477, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.8958, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.9028, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.8642, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.8760, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.8833, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.8435, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.8629, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 35%|███▍      | 341/976 [01:36<02:54,  3.64it/s]

torch.Size([])
torch.Size([128])
tensor(0.9334, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.8692, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.8879, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.8642, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.8563, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.8169, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.8581, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.7883, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.8129, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 35%|███▌      | 342/976 [01:37<02:53,  3.66it/s]

torch.Size([])
torch.Size([128])
tensor(0.9005, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.8779, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.8199, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.8101, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.8559, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.8097, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.8095, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.7042, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.7638, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 35%|███▌      | 343/976 [01:37<02:43,  3.88it/s]

torch.Size([])
torch.Size([128])
tensor(0.8350, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.7678, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.8153, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.8430, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.7565, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.7959, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.7570, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.7283, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.7387, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 35%|███▌      | 344/976 [01:37<02:49,  3.73it/s]

torch.Size([])
torch.Size([128])
tensor(0.7897, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.7990, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.7473, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.7798, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.7513, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.7040, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.7165, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.7236, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.7066, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 35%|███▌      | 345/976 [01:38<02:52,  3.65it/s]

torch.Size([])
torch.Size([128])
tensor(0.7499, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.7364, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.7764, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.7393, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.7299, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.6878, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.7130, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.6569, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.6751, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 35%|███▌      | 346/976 [01:38<02:57,  3.56it/s]

torch.Size([])
torch.Size([128])
tensor(0.7507, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.7100, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.7164, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.6812, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.6771, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.6868, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.6868, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5984, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.6066, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 36%|███▌      | 347/976 [01:38<02:58,  3.52it/s]

torch.Size([])
torch.Size([128])
tensor(0.7198, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.6613, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.7093, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.6264, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.6375, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.6365, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.6154, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.6228, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.6346, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 36%|███▌      | 348/976 [01:38<02:58,  3.51it/s]

torch.Size([])
torch.Size([128])
tensor(0.6425, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.6730, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.6327, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.6506, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.6033, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.6161, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5896, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5907, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5393, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 36%|███▌      | 349/976 [01:39<02:54,  3.59it/s]

torch.Size([])
torch.Size([128])
tensor(0.6602, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.6458, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5985, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5689, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5852, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5642, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5751, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5528, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5561, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 36%|███▌      | 350/976 [01:39<02:44,  3.80it/s]

torch.Size([])
torch.Size([128])
tensor(0.5973, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5885, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5850, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5762, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5646, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5663, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5229, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5052, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5075, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 36%|███▌      | 351/976 [01:39<02:40,  3.90it/s]

torch.Size([])
torch.Size([128])
tensor(0.5866, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5379, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5545, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5466, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5569, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5123, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4959, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4777, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4792, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 36%|███▌      | 352/976 [01:39<02:33,  4.07it/s]

torch.Size([])
torch.Size([128])
tensor(0.5262, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5312, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5390, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5077, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5062, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4985, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4580, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4641, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4541, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 36%|███▌      | 353/976 [01:40<02:30,  4.13it/s]

torch.Size([])
torch.Size([128])
tensor(0.5356, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4823, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4871, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4975, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4727, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4654, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4463, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4442, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4401, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 36%|███▋      | 354/976 [01:40<02:30,  4.12it/s]

torch.Size([])
torch.Size([128])
tensor(0.5025, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4575, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4599, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4658, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4344, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4467, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4286, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4074, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4173, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 36%|███▋      | 355/976 [01:40<02:38,  3.91it/s]

torch.Size([])
torch.Size([128])
tensor(48.2474, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(77.9711, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(72.1327, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(69.1392, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(82.0641, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(61.9957, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(64.3447, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(59.2516, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(86.1681, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 36%|███▋      | 356/976 [01:40<02:43,  3.80it/s]

torch.Size([])
torch.Size([128])
tensor(0.4623, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4614, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4794, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4750, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4665, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4818, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4590, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4877, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4863, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 37%|███▋      | 357/976 [01:41<02:37,  3.94it/s]

torch.Size([])
torch.Size([128])
tensor(0.4528, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4241, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4443, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4595, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4263, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3991, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4124, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4127, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3866, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 37%|███▋      | 358/976 [01:41<02:36,  3.94it/s]

torch.Size([])
torch.Size([128])
tensor(78.5393, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.9690, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.2042, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.7724, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.4110, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.7194, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(43.2734, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(38.1930, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.5212, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
to

 37%|███▋      | 359/976 [01:41<02:38,  3.90it/s]

torch.Size([])
torch.Size([128])
tensor(0.4368, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4378, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4398, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4555, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4375, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4632, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4313, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4527, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4393, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 37%|███▋      | 360/976 [01:41<02:46,  3.70it/s]

torch.Size([])
torch.Size([128])
tensor(0.4423, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4111, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4106, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4075, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3799, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3898, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3923, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3820, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3786, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 37%|███▋      | 361/976 [01:42<02:49,  3.63it/s]

torch.Size([])
torch.Size([128])
tensor(39.5369, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(41.0476, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.3868, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.3636, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.4867, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.3549, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(53.1662, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.3801, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.7834, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 37%|███▋      | 362/976 [01:42<02:51,  3.57it/s]

torch.Size([])
torch.Size([128])
tensor(0.3846, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4194, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4237, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4448, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4273, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4328, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4123, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4162, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4094, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 37%|███▋      | 363/976 [01:42<02:53,  3.53it/s]

torch.Size([])
torch.Size([128])
tensor(0.4215, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3911, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3913, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3797, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3718, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3553, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3718, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3612, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3582, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 37%|███▋      | 364/976 [01:43<02:55,  3.48it/s]

torch.Size([])
torch.Size([128])
tensor(0.3634, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3795, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3700, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3586, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3588, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3159, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3335, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3221, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3216, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 37%|███▋      | 365/976 [01:43<02:51,  3.56it/s]

torch.Size([])
torch.Size([128])
tensor(0.3582, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3529, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3507, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3528, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3198, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3292, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3165, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3055, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3005, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 38%|███▊      | 366/976 [01:43<02:45,  3.70it/s]

torch.Size([])
torch.Size([128])
tensor(0.3331, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3481, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3156, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3042, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2977, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3020, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2866, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2761, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2613, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 38%|███▊      | 367/976 [01:43<02:45,  3.68it/s]

torch.Size([])
torch.Size([128])
tensor(0.3078, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3067, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2905, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3067, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2763, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2742, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2646, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2636, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2579, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 38%|███▊      | 368/976 [01:44<02:39,  3.81it/s]

torch.Size([])
torch.Size([128])
tensor(0.2912, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2797, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2704, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2843, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2638, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2486, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2494, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2362, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2290, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 38%|███▊      | 369/976 [01:44<02:46,  3.65it/s]

torch.Size([])
torch.Size([128])
tensor(0.2605, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2750, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2738, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2385, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2427, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2382, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2213, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2222, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2030, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 38%|███▊      | 370/976 [01:44<02:48,  3.60it/s]

torch.Size([])
torch.Size([128])
tensor(0.2584, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2400, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2333, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2354, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2174, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2256, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2060, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1996, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1876, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 38%|███▊      | 371/976 [01:45<02:50,  3.55it/s]

torch.Size([])
torch.Size([128])
tensor(0.2377, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2267, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2112, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2187, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2057, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2031, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1912, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1811, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1821, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 38%|███▊      | 372/976 [01:45<02:51,  3.52it/s]

torch.Size([])
torch.Size([128])
tensor(0.2201, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2016, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1988, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2007, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1874, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1805, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1797, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1652, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1668, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 38%|███▊      | 373/976 [01:45<02:50,  3.54it/s]

torch.Size([])
torch.Size([128])
tensor(0.1803, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1991, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1920, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1799, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1717, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1613, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1622, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1532, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1566, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 38%|███▊      | 374/976 [01:45<02:43,  3.69it/s]

torch.Size([])
torch.Size([128])
tensor(0.1833, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1737, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1605, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1699, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1507, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1552, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1390, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1437, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1332, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 38%|███▊      | 375/976 [01:46<02:46,  3.60it/s]

torch.Size([])
torch.Size([128])
tensor(0.1648, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1593, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1604, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1396, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1380, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1314, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1347, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1257, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1254, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 39%|███▊      | 376/976 [01:46<02:47,  3.57it/s]

torch.Size([])
torch.Size([128])
tensor(0.1547, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1524, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1303, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1273, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1251, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1223, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1213, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1065, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1100, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 39%|███▊      | 377/976 [01:46<02:47,  3.58it/s]

torch.Size([])
torch.Size([128])
tensor(0.1336, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1291, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1263, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1183, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1143, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1095, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0978, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1015, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0929, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 39%|███▊      | 378/976 [01:46<02:40,  3.73it/s]

torch.Size([])
torch.Size([128])
tensor(0.1213, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1199, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1110, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1016, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0936, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0952, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0965, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0887, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0764, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 39%|███▉      | 379/976 [01:47<02:44,  3.62it/s]

torch.Size([])
torch.Size([128])
tensor(0.1076, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1008, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0990, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0957, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0888, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0812, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0799, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0787, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0699, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 39%|███▉      | 380/976 [01:47<02:46,  3.59it/s]

torch.Size([])
torch.Size([128])
tensor(0.0923, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0954, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0843, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0836, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0802, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0719, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0678, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0663, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0634, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 39%|███▉      | 381/976 [01:47<02:38,  3.76it/s]

torch.Size([])
torch.Size([128])
tensor(0.0846, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0815, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0751, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0700, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0678, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0656, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0590, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0542, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0559, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 39%|███▉      | 382/976 [01:48<02:43,  3.64it/s]

torch.Size([])
torch.Size([128])
tensor(0.0721, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0729, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0656, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0595, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0557, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0559, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0497, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0488, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0465, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 39%|███▉      | 383/976 [01:48<02:45,  3.58it/s]

torch.Size([])
torch.Size([128])
tensor(0.0655, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0601, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0534, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0529, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0491, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0491, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0413, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0376, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0376, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 39%|███▉      | 384/976 [01:48<02:44,  3.60it/s]

torch.Size([])
torch.Size([128])
tensor(0.0548, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0495, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0454, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0468, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0415, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0388, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0362, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0303, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0293, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 39%|███▉      | 385/976 [01:48<02:37,  3.74it/s]

torch.Size([])
torch.Size([128])
tensor(51.9426, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(61.0499, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.6081, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.6678, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(62.3457, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.6803, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.5013, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(40.8637, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(45.6534, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 40%|███▉      | 386/976 [01:49<02:38,  3.72it/s]

torch.Size([])
torch.Size([128])
tensor(0.0461, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0495, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0521, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0471, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0490, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0534, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0493, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0488, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0467, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 40%|███▉      | 387/976 [01:49<02:40,  3.68it/s]

torch.Size([])
torch.Size([128])
tensor(0.0461, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0415, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0386, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0386, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0351, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0339, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0305, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0283, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0267, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 40%|███▉      | 388/976 [01:49<02:34,  3.82it/s]

torch.Size([])
torch.Size([128])
tensor(0.0379, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0338, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0333, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0308, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0280, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0251, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0231, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0206, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0188, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 40%|███▉      | 389/976 [01:49<02:33,  3.81it/s]

torch.Size([])
torch.Size([128])
tensor(53.7776, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.9632, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.7088, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(48.6358, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(57.9909, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.9267, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(43.0291, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.2477, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(62.2348, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 40%|███▉      | 390/976 [01:50<02:31,  3.88it/s]

torch.Size([])
torch.Size([128])
tensor(17.1923, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(59.3696, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(38.1407, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(37.4278, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(37.5983, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.0437, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(46.0307, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(39.2372, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(44.3800, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 40%|████      | 391/976 [01:50<02:38,  3.70it/s]

torch.Size([])
torch.Size([128])
tensor(93.7158, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(76.5468, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(60.3725, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(66.6009, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(72.4795, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(49.9538, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(91.2441, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(82.9077, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(83.0946, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 40%|████      | 392/976 [01:50<02:37,  3.71it/s]

torch.Size([])
torch.Size([128])
tensor(28.2160, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(40.9959, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.0849, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(45.9117, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(46.3706, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.6787, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(16.9622, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(60.9514, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(17.6476, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 40%|████      | 393/976 [01:50<02:28,  3.91it/s]

torch.Size([])
torch.Size([128])
tensor(34.3033, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.4821, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(58.2511, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.0569, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.2215, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(44.1913, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(16.8354, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(74.4560, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(40.1109, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 40%|████      | 394/976 [01:51<02:35,  3.73it/s]

torch.Size([])
torch.Size([128])
tensor(40.0771, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(47.9475, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.3647, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(37.7751, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(80.1159, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.9485, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.6220, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(44.1558, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(45.5803, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 40%|████      | 395/976 [01:51<02:39,  3.65it/s]

torch.Size([])
torch.Size([128])
tensor(110.9719, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(96.6009, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(46.3656, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(49.4094, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(80.2548, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(85.2720, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(65.5713, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(71.4909, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(69.1218, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)


 41%|████      | 396/976 [01:51<02:42,  3.57it/s]

torch.Size([])
torch.Size([128])
tensor(34.4985, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(53.0990, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.4370, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.2705, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.1490, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(62.4109, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(55.3013, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.0836, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(45.9501, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 41%|████      | 397/976 [01:52<02:44,  3.52it/s]

torch.Size([])
torch.Size([128])
tensor(0.0963, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1034, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1064, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1091, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1070, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1104, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1151, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1111, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1100, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 41%|████      | 398/976 [01:52<02:41,  3.58it/s]

torch.Size([])
torch.Size([128])
tensor(25.0662, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(126.5528, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.7436, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(112.4636, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(69.5807, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(72.4974, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(74.8824, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(83.9041, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(44.6938, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 41%|████      | 399/976 [01:52<02:30,  3.84it/s]

torch.Size([])
torch.Size([128])
tensor(39.8284, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.4420, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(45.5154, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.2274, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(39.6758, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.3313, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(54.0466, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.7179, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(57.6692, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 41%|████      | 400/976 [01:52<02:29,  3.85it/s]

torch.Size([])
torch.Size([128])
tensor(0.1135, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(17.5782, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1244, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1222, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1250, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1217, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(17.5822, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1354, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(17.5859, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 41%|████      | 401/976 [01:53<02:21,  4.07it/s]

torch.Size([])
torch.Size([128])
tensor(80.5879, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(49.3365, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(124.9417, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(177.1330, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(85.1507, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(96.6653, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(133.6153, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(116.4604, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(156.8835, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward

 41%|████      | 402/976 [01:53<02:16,  4.20it/s]

torch.Size([])
torch.Size([128])
tensor(52.8554, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(49.4640, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.1994, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(24.9214, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(41.1072, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.8439, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.1832, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(69.0733, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(40.5914, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 41%|████▏     | 403/976 [01:53<02:17,  4.18it/s]

torch.Size([])
torch.Size([128])
tensor(126.3102, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(118.7030, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(82.0503, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(105.3084, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(110.8293, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(115.1595, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(115.4059, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(90.0462, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(98.6870, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackwar

 41%|████▏     | 404/976 [01:53<02:18,  4.12it/s]

torch.Size([])
torch.Size([128])
tensor(0.1418, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1496, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1584, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1580, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1576, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1616, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1650, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1570, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1606, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 41%|████▏     | 405/976 [01:54<02:17,  4.15it/s]

torch.Size([])
torch.Size([128])
tensor(60.8335, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(70.7542, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(42.1716, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(118.0590, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(77.6231, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(96.6516, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(82.0512, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.5379, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(51.3186, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)


 42%|████▏     | 406/976 [01:54<02:23,  3.97it/s]

torch.Size([])
torch.Size([128])
tensor(66.7543, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(94.8195, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(50.8237, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(61.9746, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(55.6107, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.4207, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(149.6749, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(40.1942, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(72.9581, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)


 42%|████▏     | 407/976 [01:54<02:18,  4.10it/s]

torch.Size([])
torch.Size([128])
tensor(85.7434, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(80.2672, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(50.9319, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(70.2911, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(38.2699, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(60.8420, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(86.9518, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(100.5428, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(40.0580, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)


 42%|████▏     | 408/976 [01:54<02:18,  4.09it/s]

torch.Size([])
torch.Size([128])
tensor(82.5091, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.9521, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(73.9892, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(100.7484, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(82.2807, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(58.4057, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.7128, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(119.1245, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.8795, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 42%|████▏     | 409/976 [01:55<02:21,  4.01it/s]

torch.Size([])
torch.Size([128])
tensor(81.0560, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(41.7800, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(44.7866, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(84.2766, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(64.3153, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(69.6917, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(74.5513, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(42.9127, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(95.4553, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 42%|████▏     | 410/976 [01:55<02:17,  4.12it/s]

torch.Size([])
torch.Size([128])
tensor(48.9617, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(75.4987, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(83.6151, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(78.4783, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(58.9410, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(79.6542, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(67.6078, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(79.6986, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(55.8703, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 42%|████▏     | 411/976 [01:55<02:24,  3.92it/s]

torch.Size([])
torch.Size([128])
tensor(71.6721, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(62.7627, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(93.6204, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(56.7568, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(65.8725, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(99.3524, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(50.6550, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(68.2775, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(47.5984, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 42%|████▏     | 412/976 [01:55<02:29,  3.76it/s]

torch.Size([])
torch.Size([128])
tensor(67.0186, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(88.4812, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(48.2762, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(79.2007, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(60.6526, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(80.0209, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(63.7912, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(77.8431, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(67.3073, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 42%|████▏     | 413/976 [01:56<02:34,  3.65it/s]

torch.Size([])
torch.Size([128])
tensor(23.8941, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.6431, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(59.0011, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.6385, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(41.8930, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(43.0002, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(15.8041, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(40.2484, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.4201, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 42%|████▏     | 414/976 [01:56<02:36,  3.58it/s]

torch.Size([])
torch.Size([128])
tensor(68.3890, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(104.1732, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(106.2227, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(105.4589, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(134.3302, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(92.1650, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(103.9450, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(53.0255, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(95.3138, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward

 43%|████▎     | 415/976 [01:56<02:34,  3.64it/s]

torch.Size([])
torch.Size([128])
tensor(69.5125, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(67.1924, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(55.7767, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(84.4488, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.1898, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(72.0301, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(57.9184, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(116.1459, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(64.5418, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)


 43%|████▎     | 416/976 [01:56<02:29,  3.75it/s]

torch.Size([])
torch.Size([128])
tensor(61.7642, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(87.2308, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(72.9476, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(53.5100, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(67.4557, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(77.4954, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(40.7915, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(89.0945, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(75.6481, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 43%|████▎     | 417/976 [01:57<02:27,  3.78it/s]

torch.Size([])
torch.Size([128])
tensor(71.5519, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(40.0899, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(95.6678, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.0902, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(48.1608, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(39.4841, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(87.5442, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(56.8009, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(66.3191, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 43%|████▎     | 418/976 [01:57<02:26,  3.81it/s]

torch.Size([])
torch.Size([128])
tensor(173.5577, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(121.0851, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(130.0355, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(118.5104, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(141.4238, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(138.5631, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(130.9501, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(130.8640, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(86.9256, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackw

 43%|████▎     | 419/976 [01:57<02:31,  3.67it/s]

torch.Size([])
torch.Size([128])
tensor(54.4791, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(76.8252, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(100.6939, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(38.6776, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(90.3989, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(43.5259, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(83.3532, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(52.7935, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(52.9803, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)


 43%|████▎     | 420/976 [01:58<02:34,  3.61it/s]

torch.Size([])
torch.Size([128])
tensor(46.5021, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(76.3908, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(55.4039, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(77.8718, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(75.8218, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(80.6399, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(40.2013, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(59.0793, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(42.9769, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 43%|████▎     | 421/976 [01:58<02:32,  3.64it/s]

torch.Size([])
torch.Size([128])
tensor(119.3873, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(133.5603, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(94.6577, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(185.0797, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(157.5076, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(112.7318, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(139.9541, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(121.1997, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(184.7077, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackw

 43%|████▎     | 422/976 [01:58<02:28,  3.74it/s]

torch.Size([])
torch.Size([128])
tensor(117.8828, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(101.8481, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(118.3334, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(59.2523, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(100.0529, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(91.9686, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(131.3864, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(72.9248, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(110.6881, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackwar

 43%|████▎     | 423/976 [01:58<02:28,  3.73it/s]

torch.Size([])
torch.Size([128])
tensor(55.4716, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(61.6756, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(56.6876, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(89.7894, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.2871, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(70.8803, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(86.2994, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(73.6019, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(64.4045, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 43%|████▎     | 424/976 [01:59<02:33,  3.59it/s]

torch.Size([])
torch.Size([128])
tensor(91.3551, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(87.3883, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(82.6278, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(125.6158, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(118.5364, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(151.8776, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(67.2624, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(48.4908, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(109.5575, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0

 44%|████▎     | 425/976 [01:59<02:28,  3.72it/s]

torch.Size([])
torch.Size([128])
tensor(150.7216, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(125.2235, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(57.3738, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(56.2703, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(97.9356, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(88.9536, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(107.6235, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(94.0535, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(96.8125, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>

 44%|████▎     | 426/976 [01:59<02:26,  3.76it/s]

torch.Size([])
torch.Size([128])
tensor(74.7738, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(50.4253, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(50.1201, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(83.5502, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(53.5812, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.5018, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(57.0635, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(115.1549, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(72.2640, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)


 44%|████▍     | 427/976 [01:59<02:30,  3.65it/s]

torch.Size([])
torch.Size([128])
tensor(73.7714, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(62.9489, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(56.2953, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(56.6100, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.1776, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(76.6836, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(76.3749, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(84.9015, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(89.8730, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 44%|████▍     | 428/976 [02:00<02:28,  3.69it/s]

torch.Size([])
torch.Size([128])
tensor(33.4868, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(41.8862, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(88.0956, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(57.7587, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(81.1560, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(71.9837, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.8239, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(46.9962, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(64.7331, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 44%|████▍     | 429/976 [02:00<02:25,  3.76it/s]

torch.Size([])
torch.Size([128])
tensor(53.8689, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(78.6627, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(60.9268, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(60.6149, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(75.6202, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(73.5828, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(53.7958, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(50.5624, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.4230, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 44%|████▍     | 430/976 [02:00<02:26,  3.72it/s]

torch.Size([])
torch.Size([128])
tensor(60.8156, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(83.6351, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(52.0515, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(48.4381, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(71.4762, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(41.0772, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(61.0271, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(70.9277, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(82.6740, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 44%|████▍     | 431/976 [02:00<02:19,  3.90it/s]

torch.Size([])
torch.Size([128])
tensor(156.4607, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(60.2650, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(76.7579, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(81.3291, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(88.1920, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(106.6167, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(105.5548, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(73.5027, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(61.3591, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>

 44%|████▍     | 432/976 [02:01<02:22,  3.81it/s]

torch.Size([])
torch.Size([128])
tensor(30.3792, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.6331, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(46.6343, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(14.1344, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.3995, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.4874, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.8356, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(63.9372, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(40.8192, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 44%|████▍     | 433/976 [02:01<02:18,  3.91it/s]

torch.Size([])
torch.Size([128])
tensor(68.4411, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(52.2301, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(82.7375, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(110.8039, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(82.0500, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(118.7802, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(41.4715, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(71.5006, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(83.0177, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 44%|████▍     | 434/976 [02:01<02:20,  3.87it/s]

torch.Size([])
torch.Size([128])
tensor(37.6558, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(15.6129, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(46.6314, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.1739, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.4945, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.6029, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(50.5756, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.2741, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.3231, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 45%|████▍     | 435/976 [02:01<02:23,  3.77it/s]

torch.Size([])
torch.Size([128])
tensor(77.4521, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(45.4224, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(79.9356, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(69.8680, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(64.1964, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(93.7080, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(31.5686, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(82.6878, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(62.9128, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 45%|████▍     | 436/976 [02:02<02:17,  3.92it/s]

torch.Size([])
torch.Size([128])
tensor(62.2388, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(71.1262, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(68.7545, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(41.4724, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(76.8422, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(54.1442, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(54.5664, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(57.5368, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(62.2713, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 45%|████▍     | 437/976 [02:02<02:19,  3.86it/s]

torch.Size([])
torch.Size([128])
tensor(72.8508, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(52.8124, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(51.9828, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(64.3818, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(57.1570, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(98.6550, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.0266, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(57.6434, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(46.2807, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 45%|████▍     | 438/976 [02:02<02:19,  3.85it/s]

torch.Size([])
torch.Size([128])
tensor(93.8458, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(79.3123, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(70.2060, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(114.3087, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(65.8860, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(140.6404, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(79.7504, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(69.2559, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(81.0771, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)

 45%|████▍     | 439/976 [02:02<02:17,  3.90it/s]

torch.Size([])
torch.Size([128])
tensor(17.8343, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(42.0621, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.7550, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.3981, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.6023, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(64.9291, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.1014, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.3887, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(21.0610, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
to

 45%|████▌     | 440/976 [02:03<02:18,  3.86it/s]

torch.Size([])
torch.Size([128])
tensor(82.3514, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(94.1770, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(97.9750, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(59.4393, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(110.8766, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(53.1233, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(85.3219, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(82.5103, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(78.1782, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)


 45%|████▌     | 441/976 [02:03<02:18,  3.86it/s]

torch.Size([])
torch.Size([128])
tensor(89.5733, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(32.7436, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(55.9200, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(41.8732, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(29.2561, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(52.3416, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(63.0511, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(72.5754, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(63.4462, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 45%|████▌     | 442/976 [02:03<02:11,  4.07it/s]

torch.Size([])
torch.Size([128])
tensor(47.1579, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.1877, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(81.9697, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(38.1517, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(44.0873, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(57.7325, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(57.1185, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.7891, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(38.5970, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 45%|████▌     | 443/976 [02:04<02:16,  3.89it/s]

torch.Size([])
torch.Size([128])
tensor(22.5568, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.0211, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.2959, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.3021, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.7394, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.8221, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.0565, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(22.0248, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.7175, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 45%|████▌     | 444/976 [02:04<02:22,  3.73it/s]

torch.Size([])
torch.Size([128])
tensor(40.2688, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(34.7778, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.8531, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(17.9186, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(27.8541, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.1828, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.8987, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.2367, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.4726, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 46%|████▌     | 445/976 [02:04<02:26,  3.64it/s]

torch.Size([])
torch.Size([128])
tensor(10.4185, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.9845, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.4544, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.1197, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.2192, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.2642, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.0426, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.8728, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.1266, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
tor

 46%|████▌     | 446/976 [02:04<02:28,  3.56it/s]

torch.Size([])
torch.Size([128])
tensor(8.3778, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.5420, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.2225, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.2538, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.6939, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.6240, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.0148, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.4788, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.7103, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.

 46%|████▌     | 447/976 [02:05<02:27,  3.58it/s]

torch.Size([])
torch.Size([128])
tensor(4.9218, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.9560, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.6275, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.3915, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.1592, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.1004, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.1967, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.9393, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.1119, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 46%|████▌     | 448/976 [02:05<02:17,  3.84it/s]

torch.Size([])
torch.Size([128])
tensor(6.8077, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.9158, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.2975, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.4531, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.9923, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.4664, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.6188, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.5784, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.2888, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 46%|████▌     | 449/976 [02:05<02:15,  3.89it/s]

torch.Size([])
torch.Size([128])
tensor(9.9082, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.4102, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.7007, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.8284, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.3466, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.6909, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.1152, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.4389, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.4673, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 46%|████▌     | 450/976 [02:05<02:09,  4.05it/s]

torch.Size([])
torch.Size([128])
tensor(4.2158, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.4854, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.0480, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4003, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.7581, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2458, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.7544, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1122, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3797, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 46%|████▌     | 451/976 [02:06<02:18,  3.80it/s]

 46%|████▋     | 452/976 [02:06<02:51,  3.05it/s]

 46%|████▋     | 453/976 [02:07<03:14,  2.69it/s]

 47%|████▋     | 454/976 [02:07<02:56,  2.95it/s]

 47%|████▋     | 455/976 [02:07<02:40,  3.24it/s]

 47%|████▋     | 456/976 [02:07<02:34,  3.36it/s]

 47%|████▋     | 457/976 [02:08<02:29,  3.46it/s]

 47%|████▋     | 458/976 [02:08<02:19,  3.71it/s]

 47%|████▋     | 459/976 [02:08<02:18,  3.74it/s]

 47%|████▋     | 460/976 [02:08<02:21,  3.64it/s]

 47%|████▋     | 461/976 [02:09<02:24,  3.57it/s]

 47%|████▋     | 462/976 [02:09<02:26,  3.52it/s]

 47%|████▋     | 463/976 [02:09<02:28,  3.45it/s]

 48%|████▊     | 464/976 [02:10<02:27,  3.46it/s]

 48%|████▊     | 465/976 [02:10<02:27,  3.46it/s]

 48%|████▊     | 466/976 [02:10<02:20,  3.62it/s]

 48%|████▊     | 467/976 [02:10<02:15,  3.77it/s]

 48%|████▊     | 468/976 [02:11<02:11,  3.87it/s]

 48%|████▊     | 469/976 [02:11<02:11,  3.85it/s]

 48%|████▊     | 470/976 [02:11<02:09,  3.92it/s]

 48%|████▊     | 471/976 [02:11<02:11,  3.85it/s]

 48%|████▊     | 472/976 [02:12<02:11,  3.84it/s]

 48%|████▊     | 473/976 [02:12<02:09,  3.89it/s]

 49%|████▊     | 474/976 [02:12<02:12,  3.78it/s]

 49%|████▊     | 475/976 [02:12<02:09,  3.88it/s]

 49%|████▉     | 476/976 [02:13<02:08,  3.89it/s]

 49%|████▉     | 477/976 [02:13<02:08,  3.89it/s]

 49%|████▉     | 478/976 [02:13<02:08,  3.87it/s]

 49%|████▉     | 479/976 [02:13<02:10,  3.81it/s]

 49%|████▉     | 480/976 [02:14<02:08,  3.87it/s]

 49%|████▉     | 481/976 [02:14<02:07,  3.90it/s]

 49%|████▉     | 482/976 [02:14<02:07,  3.88it/s]

torch.Size([])
torch.Size([128])
tensor(0.6632, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.6394, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.6206, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.6156, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.6077, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5675, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5763, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5546, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5264, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 49%|████▉     | 483/976 [02:15<02:06,  3.91it/s]

torch.Size([])
torch.Size([128])
tensor(55.3996, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(41.7071, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.7813, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.1893, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(53.1328, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.9870, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.6926, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.1871, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(47.2551, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 50%|████▉     | 484/976 [02:15<02:02,  4.03it/s]

torch.Size([])
torch.Size([128])
tensor(0.6246, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.6305, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.6393, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.6181, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.6182, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.6277, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.6336, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.6125, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.6136, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 50%|████▉     | 485/976 [02:15<02:07,  3.87it/s]

torch.Size([])
torch.Size([128])
tensor(0.5985, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5918, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5773, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5619, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5765, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5170, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5140, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5366, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5097, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 50%|████▉     | 486/976 [02:15<02:09,  3.80it/s]

torch.Size([])
torch.Size([128])
tensor(0.5711, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5664, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5763, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5569, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5334, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4907, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5250, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5092, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4773, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 50%|████▉     | 487/976 [02:16<02:02,  4.00it/s]

torch.Size([])
torch.Size([128])
tensor(0.5601, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5345, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5271, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5101, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4978, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4862, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4606, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4773, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4592, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 50%|█████     | 488/976 [02:16<02:03,  3.94it/s]

torch.Size([])
torch.Size([128])
tensor(0.5015, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4867, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4824, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4923, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4507, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4609, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4413, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4111, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4039, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 50%|█████     | 489/976 [02:16<02:03,  3.95it/s]

torch.Size([])
torch.Size([128])
tensor(0.4734, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4708, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4423, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4454, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4256, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4084, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3968, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4078, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3680, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 50%|█████     | 490/976 [02:16<01:59,  4.05it/s]

torch.Size([])
torch.Size([128])
tensor(69.7433, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(37.9657, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(92.2506, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(65.5505, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(45.4260, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(67.3657, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(91.4694, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(61.4375, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(90.8776, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 50%|█████     | 491/976 [02:17<02:02,  3.95it/s]

torch.Size([])
torch.Size([128])
tensor(0.4689, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4555, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4479, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4579, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4639, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4482, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4830, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4551, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4692, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 50%|█████     | 492/976 [02:17<01:57,  4.12it/s]

torch.Size([])
torch.Size([128])
tensor(0.4170, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4231, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4286, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4176, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4011, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3909, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3935, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3557, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3825, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 51%|█████     | 493/976 [02:17<02:02,  3.94it/s]

torch.Size([])
torch.Size([128])
tensor(0.4053, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4076, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4025, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3900, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3675, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3630, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3516, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3520, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3318, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 51%|█████     | 494/976 [02:17<02:08,  3.76it/s]

torch.Size([])
torch.Size([128])
tensor(0.3786, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3661, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3863, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3688, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3395, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3404, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3310, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3193, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3140, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 51%|█████     | 495/976 [02:18<02:12,  3.64it/s]

torch.Size([])
torch.Size([128])
tensor(0.3557, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3611, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3287, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3362, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3183, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3041, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3009, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2949, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2769, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 51%|█████     | 496/976 [02:18<02:14,  3.56it/s]

torch.Size([])
torch.Size([128])
tensor(0.3345, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3376, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3037, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3064, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2915, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2885, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2738, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2712, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2366, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 51%|█████     | 497/976 [02:18<02:30,  3.18it/s]

torch.Size([])
torch.Size([128])
tensor(0.3410, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3221, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3032, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2820, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2964, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2753, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2587, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2602, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2456, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 51%|█████     | 498/976 [02:19<02:24,  3.30it/s]

torch.Size([])
torch.Size([128])
tensor(0.2865, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2750, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2671, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2680, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2425, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2489, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2312, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2303, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2217, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 51%|█████     | 499/976 [02:19<02:13,  3.57it/s]

torch.Size([])
torch.Size([128])
tensor(0.2567, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2611, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2399, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2309, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2317, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2083, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2083, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2049, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1940, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 51%|█████     | 500/976 [02:19<02:07,  3.72it/s]

torch.Size([])
torch.Size([128])
tensor(0.2306, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2266, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2361, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2158, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2091, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1969, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1920, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1828, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1801, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 51%|█████▏    | 501/976 [02:19<02:00,  3.96it/s]

torch.Size([])
torch.Size([128])
tensor(52.8127, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(40.5579, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(30.0612, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.7566, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(15.2547, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(50.8983, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(37.6974, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(39.4446, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.9299, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 51%|█████▏    | 502/976 [02:19<01:56,  4.05it/s]

torch.Size([])
torch.Size([128])
tensor(54.1077, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.5453, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(48.1053, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(26.5820, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(53.6289, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.6147, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.8828, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(48.9509, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(43.8652, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
to

 52%|█████▏    | 503/976 [02:20<01:53,  4.17it/s]

torch.Size([])
torch.Size([128])
tensor(0.2366, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2450, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2377, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2510, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2513, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2423, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2536, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2546, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2418, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 52%|█████▏    | 504/976 [02:20<02:11,  3.58it/s]

torch.Size([])
torch.Size([128])
tensor(0.2323, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2215, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2309, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2089, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2159, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1933, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2019, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1840, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1782, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 52%|█████▏    | 505/976 [02:20<02:04,  3.80it/s]

torch.Size([])
torch.Size([128])
tensor(0.2150, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1935, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1944, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2039, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1760, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1692, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1804, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1663, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1571, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 52%|█████▏    | 506/976 [02:21<02:06,  3.71it/s]

torch.Size([])
torch.Size([128])
tensor(0.2119, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1823, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1691, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1776, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1684, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1603, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1601, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1383, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1414, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 52%|█████▏    | 507/976 [02:21<02:07,  3.69it/s]

torch.Size([])
torch.Size([128])
tensor(0.1734, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1714, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1585, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1636, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1516, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1438, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1318, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1328, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1259, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 52%|█████▏    | 508/976 [02:21<02:07,  3.66it/s]

torch.Size([])
torch.Size([128])
tensor(0.1630, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1541, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1477, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1356, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1310, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1223, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1267, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1182, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1118, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 52%|█████▏    | 509/976 [02:21<02:09,  3.62it/s]

torch.Size([])
torch.Size([128])
tensor(0.1448, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1377, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1303, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1209, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1232, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1097, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1059, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1001, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0916, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 52%|█████▏    | 510/976 [02:22<02:17,  3.38it/s]

torch.Size([])
torch.Size([128])
tensor(0.1264, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1207, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1172, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1084, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1037, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0962, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0931, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0906, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0807, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 52%|█████▏    | 511/976 [02:22<02:13,  3.47it/s]

torch.Size([])
torch.Size([128])
tensor(0.1168, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1062, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0990, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0925, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0902, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0850, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0800, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0758, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0724, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 52%|█████▏    | 512/976 [02:22<02:12,  3.50it/s]

torch.Size([])
torch.Size([128])
tensor(0.1016, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0901, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0878, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0851, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0798, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0749, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0695, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0633, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0642, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 53%|█████▎    | 513/976 [02:23<02:19,  3.31it/s]

torch.Size([])
torch.Size([128])
tensor(0.0908, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0752, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0784, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0700, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0673, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0645, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0592, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0523, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0493, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 53%|█████▎    | 514/976 [02:23<02:18,  3.34it/s]

torch.Size([])
torch.Size([128])
tensor(0.0727, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0722, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0665, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0582, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0576, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0525, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0467, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0475, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0442, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 53%|█████▎    | 515/976 [02:23<02:12,  3.49it/s]

torch.Size([])
torch.Size([128])
tensor(0.0625, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0606, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0510, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0524, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0469, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0429, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0395, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0377, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0318, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 53%|█████▎    | 516/976 [02:24<02:13,  3.46it/s]

torch.Size([])
torch.Size([128])
tensor(0.0524, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0475, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0452, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0420, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0377, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0342, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0335, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0286, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0260, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 53%|█████▎    | 517/976 [02:24<02:10,  3.52it/s]

torch.Size([])
torch.Size([128])
tensor(0.0441, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0384, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0384, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0326, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0300, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0277, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0246, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0238, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0205, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 53%|█████▎    | 518/976 [02:24<02:06,  3.63it/s]

torch.Size([])
torch.Size([128])
tensor(0.0339, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0326, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0312, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0260, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0239, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0214, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0196, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0171, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0152, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 53%|█████▎    | 519/976 [02:24<02:09,  3.54it/s]

torch.Size([])
torch.Size([128])
tensor(0.0287, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0238, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0228, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0206, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0184, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0153, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0143, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0122, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0106, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 53%|█████▎    | 520/976 [02:25<02:07,  3.57it/s]

torch.Size([])
torch.Size([128])
tensor(0.0218, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0193, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0169, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0144, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0137, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0113, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0100, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0084, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0073, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 53%|█████▎    | 521/976 [02:25<02:09,  3.51it/s]

torch.Size([])
torch.Size([128])
tensor(41.3581, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(54.0177, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.5265, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(28.5378, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(42.8372, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(36.3434, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(59.7667, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.5495, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.5607, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
to

 53%|█████▎    | 522/976 [02:25<02:12,  3.42it/s]

torch.Size([])
torch.Size([128])
tensor(0.0169, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0182, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0192, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0205, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0200, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0195, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0202, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0199, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0191, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 54%|█████▎    | 523/976 [02:25<02:09,  3.51it/s]

torch.Size([])
torch.Size([128])
tensor(0.0167, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0148, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0127, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0116, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0103, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0098, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0077, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0065, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0056, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 54%|█████▎    | 524/976 [02:26<02:04,  3.62it/s]

torch.Size([])
torch.Size([128])
tensor(0.0114, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0109, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0093, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0089, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0072, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0058, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0050, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0041, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0033, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 54%|█████▍    | 525/976 [02:26<02:04,  3.62it/s]

torch.Size([])
torch.Size([128])
tensor(0.0209, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0118, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0129, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0183, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0073, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0130, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0099, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0067, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0049, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 54%|█████▍    | 526/976 [02:26<02:05,  3.59it/s]

torch.Size([])
torch.Size([128])
tensor(0.0061, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0060, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0049, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0046, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0038, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0031, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0026, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0022, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0017, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 54%|█████▍    | 527/976 [02:27<02:06,  3.55it/s]

torch.Size([])
torch.Size([128])
tensor(117.4786, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(19.9847, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(99.7582, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(84.8828, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(75.0544, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(88.8136, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(63.8235, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(93.4777, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(97.3990, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)


 54%|█████▍    | 528/976 [02:27<02:05,  3.56it/s]

torch.Size([])
torch.Size([128])
tensor(70.6963, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(66.3444, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(64.8679, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(66.0159, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.6364, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(71.0307, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(101.6684, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(39.8404, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(88.3439, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)


 54%|█████▍    | 529/976 [02:27<02:03,  3.63it/s]

torch.Size([])
torch.Size([128])
tensor(0.3801, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3863, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4383, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.8424, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3577, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.6716, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0016, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1293, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.7621, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 54%|█████▍    | 530/976 [02:27<01:57,  3.79it/s]

torch.Size([])
torch.Size([128])
tensor(104.6432, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(126.8349, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(96.3200, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(104.2653, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(99.2310, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(103.5032, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(138.4410, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(79.7628, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(108.1460, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackwar

 54%|█████▍    | 531/976 [02:28<01:56,  3.83it/s]

torch.Size([])
torch.Size([128])
tensor(0.3638, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2486, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2623, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1515, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1922, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1789, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1562, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1881, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1676, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 55%|█████▍    | 532/976 [02:28<02:04,  3.56it/s]

torch.Size([])
torch.Size([128])
tensor(60.2879, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(63.9076, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(71.1937, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(68.9134, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(70.8163, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(54.3633, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(65.6151, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(59.2709, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(58.8820, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 55%|█████▍    | 533/976 [02:28<01:57,  3.77it/s]

torch.Size([])
torch.Size([128])
tensor(0.3389, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1290, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1985, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2029, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2843, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2142, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1437, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3113, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2169, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 55%|█████▍    | 534/976 [02:28<01:58,  3.74it/s]

torch.Size([])
torch.Size([128])
tensor(28.7201, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(52.5027, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.7443, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(44.7671, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(37.8997, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(38.0337, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(48.7149, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(33.2762, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.2557, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 55%|█████▍    | 535/976 [02:29<01:57,  3.74it/s]

torch.Size([])
torch.Size([128])
tensor(0.7269, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.7274, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4529, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.6598, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2440, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.3368, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.5759, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2745, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.4782, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 55%|█████▍    | 536/976 [02:29<02:02,  3.61it/s]

torch.Size([])
torch.Size([128])
tensor(14.2634, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(14.7299, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.1508, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(25.3076, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(15.1315, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.8648, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.9273, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.8670, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.5538, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 55%|█████▌    | 537/976 [02:29<02:02,  3.59it/s]

torch.Size([])
torch.Size([128])
tensor(3.9304, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8630, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.3260, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4954, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9383, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.1099, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2113, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1631, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.8431, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 55%|█████▌    | 538/976 [02:30<01:59,  3.67it/s]

torch.Size([])
torch.Size([128])
tensor(8.1852, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.7841, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.6296, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.8417, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.2570, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.3014, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.0600, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.2061, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.4290, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 55%|█████▌    | 539/976 [02:30<01:55,  3.77it/s]

torch.Size([])
torch.Size([128])
tensor(6.5018, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.0575, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.8437, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.2290, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.7981, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.8051, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.3185, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.8342, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.1711, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 55%|█████▌    | 540/976 [02:30<01:54,  3.81it/s]

torch.Size([])
torch.Size([128])
tensor(1.9930, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4877, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1341, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7439, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7045, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.9551, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3235, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6639, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.6453, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 55%|█████▌    | 541/976 [02:30<01:54,  3.79it/s]

torch.Size([])
torch.Size([128])
tensor(2.5941, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.3662, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.3841, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.2891, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.9351, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.1686, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.9423, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.2250, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.1040, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 56%|█████▌    | 542/976 [02:31<01:57,  3.71it/s]

torch.Size([])
torch.Size([128])
tensor(28.1153, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(50.1653, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(35.6677, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(58.2449, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(53.6973, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(20.8646, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(37.8369, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(59.4605, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(18.8428, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 56%|█████▌    | 543/976 [02:31<01:50,  3.91it/s]

torch.Size([])
torch.Size([128])
tensor(2.3374, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.7965, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.2080, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.9817, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6287, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.7189, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.2500, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6266, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.2384, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 56%|█████▌    | 544/976 [02:31<01:54,  3.79it/s]

torch.Size([])
torch.Size([128])
tensor(3.0794, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4512, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.0925, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.9002, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.1900, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.2422, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.7212, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2452, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.5241, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 56%|█████▌    | 545/976 [02:31<01:57,  3.67it/s]

torch.Size([])
torch.Size([128])
tensor(10.5655, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.0007, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.8473, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(10.0542, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.1150, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.3610, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.8583, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.6048, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.0761, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.

 56%|█████▌    | 546/976 [02:32<01:59,  3.59it/s]

torch.Size([])
torch.Size([128])
tensor(13.2828, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.3205, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(15.1797, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(23.6258, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(17.9741, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(13.3197, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(12.7859, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(17.9116, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(11.8076, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
t

 56%|█████▌    | 547/976 [02:32<02:02,  3.49it/s]

torch.Size([])
torch.Size([128])
tensor(0.0189, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0272, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0283, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0382, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0245, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0326, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0410, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0286, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0263, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 56%|█████▌    | 548/976 [02:32<02:00,  3.55it/s]

torch.Size([])
torch.Size([128])
tensor(0.1724, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2008, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2153, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2388, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2126, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1408, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1790, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1455, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1144, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 56%|█████▋    | 549/976 [02:33<01:53,  3.76it/s]

torch.Size([])
torch.Size([128])
tensor(0.1778, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.2853, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1976, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1516, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1422, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1035, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1312, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1025, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.1184, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 56%|█████▋    | 550/976 [02:33<01:51,  3.83it/s]

torch.Size([])
torch.Size([128])
tensor(0.0297, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0276, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0344, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0167, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0226, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0190, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0269, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0156, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0189, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 56%|█████▋    | 551/976 [02:33<01:48,  3.92it/s]

torch.Size([])
torch.Size([128])
tensor(0.0104, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0090, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0078, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0061, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0053, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0038, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0029, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0023, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0017, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 57%|█████▋    | 552/976 [02:33<01:44,  4.04it/s]

torch.Size([])
torch.Size([128])
tensor(0.0133, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0101, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0090, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0094, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0082, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0066, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0084, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0047, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0051, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 57%|█████▋    | 553/976 [02:33<01:44,  4.03it/s]

torch.Size([])
torch.Size([128])
tensor(0.0058, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0054, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0047, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0037, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0031, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0027, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0019, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0015, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0011, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 57%|█████▋    | 554/976 [02:34<01:44,  4.03it/s]

torch.Size([])
torch.Size([128])
tensor(0.0039, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0037, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0029, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0025, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0022, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0017, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0012, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0009, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0008, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 57%|█████▋    | 555/976 [02:34<01:49,  3.85it/s]

torch.Size([])
torch.Size([128])
tensor(0.0025, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0021, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0020, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0017, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0013, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0011, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0009, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0004, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 57%|█████▋    | 556/976 [02:34<01:51,  3.76it/s]

torch.Size([])
torch.Size([128])
tensor(0.0018, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0014, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0015, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0012, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0011, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0008, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0007, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0005, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 57%|█████▋    | 557/976 [02:35<01:52,  3.74it/s]

torch.Size([])
torch.Size([128])
tensor(0.0033, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0026, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0016, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0025, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0015, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0012, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0023, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0021, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0017, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 57%|█████▋    | 558/976 [02:35<01:49,  3.81it/s]

torch.Size([])
torch.Size([128])
tensor(0.0011, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0009, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0009, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0005, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0005, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0004, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0003, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0002, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 57%|█████▋    | 559/976 [02:35<01:45,  3.95it/s]

torch.Size([])
torch.Size([128])
tensor(0.0005, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0004, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0004, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0003, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0003, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0002, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0002, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0001, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 57%|█████▋    | 560/976 [02:35<01:48,  3.82it/s]

torch.Size([])
torch.Size([128])
tensor(0.0004, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0003, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0003, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0002, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0002, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0002, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0001, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0001, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.9926e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)

 57%|█████▋    | 561/976 [02:36<01:51,  3.71it/s]

torch.Size([])
torch.Size([128])
tensor(0.0003, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0002, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0002, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0002, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0001, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0001, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0001, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0001, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.9297e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)

 58%|█████▊    | 562/976 [02:36<01:52,  3.68it/s]

torch.Size([])
torch.Size([128])
tensor(0.0002, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0002, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0002, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0001, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0001, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.8286e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.4381e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.6405e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.5969e-05, device='cuda:0', dtype=torch.float6

 58%|█████▊    | 563/976 [02:36<01:48,  3.80it/s]

torch.Size([])
torch.Size([128])
tensor(0.0003, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0003, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0003, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0002, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0002, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0003, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0002, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0002, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0002, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size

 58%|█████▊    | 564/976 [02:36<01:51,  3.69it/s]

torch.Size([])
torch.Size([128])
tensor(0.0001, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0001, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.9222e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0001, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.8259e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.5491e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.2378e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.1864e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.2497e-05, device='cuda:

 58%|█████▊    | 565/976 [02:37<01:51,  3.67it/s]

torch.Size([])
torch.Size([128])
tensor(0.0001, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0001, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0002, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0001, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0001, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0001, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.8925e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.8772e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.0519e-05, device='cuda:0', dtype=torch.float64,
       g

 58%|█████▊    | 566/976 [02:37<01:49,  3.75it/s]

torch.Size([])
torch.Size([128])
tensor(6.8291e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.5730e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.2322e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.1567e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.8368e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.5245e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.1639e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.7232e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 58%|█████▊    | 567/976 [02:37<01:52,  3.63it/s]

torch.Size([])
torch.Size([128])
tensor(6.7407e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0001, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.9799e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.4138e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.0003e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.2298e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.6331e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(0.0001, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.5543e-05, de

 58%|█████▊    | 568/976 [02:38<01:49,  3.71it/s]

torch.Size([])
torch.Size([128])
tensor(4.6361e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.5378e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.0651e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.2610e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.1321e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.3168e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.3711e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.4117e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 58%|█████▊    | 569/976 [02:38<01:48,  3.76it/s]

torch.Size([])
torch.Size([128])
tensor(3.6291e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.7732e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.0266e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.6943e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.0654e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.7043e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.8212e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0394e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 58%|█████▊    | 570/976 [02:38<01:47,  3.79it/s]

torch.Size([])
torch.Size([128])
tensor(1.6219e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4620e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2536e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3575e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1606e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0510e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3932e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.7129e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 59%|█████▊    | 571/976 [02:38<01:48,  3.72it/s]

torch.Size([])
torch.Size([128])
tensor(6.5834e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.5319e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0331e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.2862e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.9879e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.9394e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.7886e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.6343e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 59%|█████▊    | 572/976 [02:39<01:53,  3.55it/s]

 59%|█████▊    | 573/976 [02:39<01:49,  3.69it/s]

 59%|█████▉    | 574/976 [02:39<01:48,  3.72it/s]

 59%|█████▉    | 575/976 [02:39<01:49,  3.68it/s]

 59%|█████▉    | 576/976 [02:40<01:52,  3.56it/s]

 59%|█████▉    | 577/976 [02:40<01:53,  3.53it/s]

 59%|█████▉    | 578/976 [02:40<01:51,  3.58it/s]

 59%|█████▉    | 579/976 [02:41<01:48,  3.66it/s]

 59%|█████▉    | 580/976 [02:41<01:51,  3.54it/s]

 60%|█████▉    | 581/976 [02:41<01:49,  3.61it/s]

 60%|█████▉    | 582/976 [02:41<01:56,  3.38it/s]

 60%|█████▉    | 583/976 [02:42<01:51,  3.53it/s]

 60%|█████▉    | 584/976 [02:42<01:47,  3.66it/s]

 60%|█████▉    | 585/976 [02:42<01:50,  3.52it/s]

 60%|██████    | 586/976 [02:43<01:51,  3.51it/s]

 60%|██████    | 587/976 [02:43<01:52,  3.47it/s]

 60%|██████    | 588/976 [02:43<01:53,  3.42it/s]

 60%|██████    | 589/976 [02:43<01:46,  3.63it/s]

 60%|██████    | 590/976 [02:44<01:45,  3.67it/s]

 61%|██████    | 591/976 [02:44<01:39,  3.85it/s]

 61%|██████    | 592/976 [02:44<01:42,  3.76it/s]

 61%|██████    | 593/976 [02:44<01:44,  3.65it/s]

 61%|██████    | 594/976 [02:45<01:48,  3.54it/s]

 61%|██████    | 595/976 [02:45<01:48,  3.53it/s]

 61%|██████    | 596/976 [02:45<01:48,  3.51it/s]

 61%|██████    | 597/976 [02:46<01:48,  3.50it/s]

 61%|██████▏   | 598/976 [02:46<01:46,  3.54it/s]

 61%|██████▏   | 599/976 [02:46<01:40,  3.75it/s]

 61%|██████▏   | 600/976 [02:46<01:38,  3.83it/s]

 62%|██████▏   | 601/976 [02:47<01:36,  3.89it/s]

 62%|██████▏   | 602/976 [02:47<01:33,  4.00it/s]

 62%|██████▏   | 603/976 [02:47<01:33,  3.97it/s]

 62%|██████▏   | 604/976 [02:47<01:33,  3.98it/s]

torch.Size([])
torch.Size([128])
tensor(2.5076e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2882e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5760e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4061e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6423e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4923e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.8529e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.7601e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 62%|██████▏   | 605/976 [02:48<01:36,  3.85it/s]

torch.Size([])
torch.Size([128])
tensor(2.4395e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.1712e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.9001e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.5812e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6148e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4070e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.2358e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.4839e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 62%|██████▏   | 606/976 [02:48<01:38,  3.77it/s]

torch.Size([])
torch.Size([128])
tensor(1.0539e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5806e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.2831e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.9170e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.8037e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0144e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4788e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.7911e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 62%|██████▏   | 607/976 [02:48<01:38,  3.75it/s]

torch.Size([])
torch.Size([128])
tensor(1.7890e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0359e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8556e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1527e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2367e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5940e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1007e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5463e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 62%|██████▏   | 608/976 [02:48<01:35,  3.84it/s]

torch.Size([])
torch.Size([128])
tensor(9.1524e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.3803e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2040e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.1643e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.4594e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0977e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.0624e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.5333e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 62%|██████▏   | 609/976 [02:49<01:32,  3.95it/s]

torch.Size([])
torch.Size([128])
tensor(3.5883e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.1033e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.8416e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.2888e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.5277e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.6012e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.9285e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.9481e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 62%|██████▎   | 610/976 [02:49<01:36,  3.79it/s]

torch.Size([])
torch.Size([128])
tensor(7.3886e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.1653e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1088e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1653e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1469e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.9957e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1555e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.6805e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 63%|██████▎   | 611/976 [02:49<01:39,  3.66it/s]

torch.Size([])
torch.Size([128])
tensor(4.9054e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.0025e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.4432e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.8554e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.4168e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.1463e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.2021e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.1443e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 63%|██████▎   | 612/976 [02:50<01:41,  3.59it/s]

torch.Size([])
torch.Size([128])
tensor(6.7386e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.0954e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.7141e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.7279e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.9247e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.4845e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.5534e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.7040e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 63%|██████▎   | 613/976 [02:50<01:43,  3.51it/s]

torch.Size([])
torch.Size([128])
tensor(2.9548e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.7570e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.0109e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.1094e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.9435e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.6808e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.4482e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.7452e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 63%|██████▎   | 614/976 [02:50<01:42,  3.53it/s]

torch.Size([])
torch.Size([128])
tensor(2.4087e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.6229e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.6310e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.9427e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.2252e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.5950e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9591e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0271e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 63%|██████▎   | 615/976 [02:50<01:40,  3.60it/s]

torch.Size([])
torch.Size([128])
tensor(4.0498e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.4040e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.9208e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.9490e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.5698e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.9974e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.7718e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.9679e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 63%|██████▎   | 616/976 [02:51<01:50,  3.25it/s]

torch.Size([])
torch.Size([128])
tensor(9.9199e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.8601e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.2563e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.8831e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.7618e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.1118e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.2654e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1277e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 63%|██████▎   | 617/976 [02:51<02:08,  2.79it/s]

torch.Size([])
torch.Size([128])
tensor(4.5848e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.3702e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.8844e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.4770e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.1636e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.9342e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.5992e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.6186e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 63%|██████▎   | 618/976 [02:52<02:22,  2.50it/s]

torch.Size([])
torch.Size([128])
tensor(4.1420e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.4524e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.6472e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.8163e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.3842e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.8880e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.9048e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.4592e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 63%|██████▎   | 619/976 [02:52<02:33,  2.33it/s]

torch.Size([])
torch.Size([128])
tensor(4.3676e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.6751e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.7401e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.1156e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.1532e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.9027e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.1586e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.4620e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 64%|██████▎   | 620/976 [02:53<03:01,  1.96it/s]

torch.Size([])
torch.Size([128])
tensor(6.4948e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.8406e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.2296e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.5161e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.7569e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.1798e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.9256e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.9359e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 64%|██████▎   | 621/976 [02:53<02:34,  2.30it/s]

torch.Size([])
torch.Size([128])
tensor(8.3741e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.4605e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.7179e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.0515e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.2216e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.7376e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.5005e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.9263e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 64%|██████▎   | 622/976 [02:53<02:14,  2.64it/s]

torch.Size([])
torch.Size([128])
tensor(3.6942e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.2166e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.6955e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.2581e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.7152e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.3165e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.2363e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.5682e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 64%|██████▍   | 623/976 [02:54<02:00,  2.93it/s]

torch.Size([])
torch.Size([128])
tensor(2.4890e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2128e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8141e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1194e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1899e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9073e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1751e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3651e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 64%|██████▍   | 624/976 [02:54<01:50,  3.20it/s]

torch.Size([])
torch.Size([128])
tensor(1.1085e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3425e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5639e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.4921e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5719e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.1731e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3115e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.5742e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 64%|██████▍   | 625/976 [02:54<01:43,  3.40it/s]

torch.Size([])
torch.Size([128])
tensor(3.8384e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.5198e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.6956e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.1745e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.6815e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.1164e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.7978e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.6249e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 64%|██████▍   | 626/976 [02:54<01:38,  3.56it/s]

torch.Size([])
torch.Size([128])
tensor(4.9562e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.1856e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.2539e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.8783e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.6356e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.2390e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.0921e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.5318e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 64%|██████▍   | 627/976 [02:55<01:35,  3.64it/s]

torch.Size([])
torch.Size([128])
tensor(4.5029e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.8687e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.4709e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.4521e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.7695e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.4788e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.6387e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.5985e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 64%|██████▍   | 628/976 [02:55<01:31,  3.79it/s]

torch.Size([])
torch.Size([128])
tensor(7.5907e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.0369e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.2521e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.5379e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.2622e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.1960e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.6019e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.3046e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 64%|██████▍   | 629/976 [02:55<01:33,  3.72it/s]

torch.Size([])
torch.Size([128])
tensor(4.2371e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.4701e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.5477e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.4707e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.2800e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.9197e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.5615e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.0677e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 65%|██████▍   | 630/976 [02:55<01:29,  3.89it/s]

torch.Size([])
torch.Size([128])
tensor(2.8801e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.2324e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.1403e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3609e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.0178e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1885e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.0797e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.9262e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 65%|██████▍   | 631/976 [02:56<01:27,  3.94it/s]

torch.Size([])
torch.Size([128])
tensor(4.6719e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.0707e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4139e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.8587e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.3452e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.4784e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.0590e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.6512e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.6819e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.8176e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.6529e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.6300e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4241e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.0572e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.5285e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 68%|██████▊   | 667/976 [03:05<01:19,  3.89it/s]

torch.Size([])
torch.Size([128])
tensor(5.6488e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.7086e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.9879e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.4310e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.0963e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.0931e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.1131e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.5251e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 68%|██████▊   | 668/976 [03:05<01:18,  3.90it/s]

torch.Size([])
torch.Size([128])
tensor(4.3147e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.2565e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.6730e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.2042e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.7435e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.6916e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.5386e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.5447e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 69%|██████▊   | 669/976 [03:05<01:22,  3.73it/s]

torch.Size([])
torch.Size([128])
tensor(5.2001e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.2958e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.4778e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.5324e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.8536e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.1820e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.3627e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.0963e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 69%|██████▊   | 670/976 [03:06<01:25,  3.57it/s]

torch.Size([])
torch.Size([128])
tensor(3.1286e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.0054e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.9536e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.5859e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.2202e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3362e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.6911e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8718e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 69%|██████▉   | 671/976 [03:06<01:25,  3.57it/s]

torch.Size([])
torch.Size([128])
tensor(3.8160e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.3122e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.4984e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.3039e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.5686e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.9720e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3915e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.0701e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 69%|██████▉   | 672/976 [03:06<01:26,  3.50it/s]

torch.Size([])
torch.Size([128])
tensor(2.9612e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.1241e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.6885e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.3154e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4035e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.8245e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.2565e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.6181e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 69%|██████▉   | 673/976 [03:07<01:25,  3.56it/s]

torch.Size([])
torch.Size([128])
tensor(3.1717e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2902e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.7048e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3872e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.6692e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0817e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.6382e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.8282e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 69%|██████▉   | 674/976 [03:07<01:21,  3.72it/s]

torch.Size([])
torch.Size([128])
tensor(4.3942e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.1927e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4440e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.2418e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.9997e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.9891e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.6085e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.3605e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 69%|██████▉   | 675/976 [03:07<01:22,  3.65it/s]

torch.Size([])
torch.Size([128])
tensor(5.8872e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.5574e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.6513e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.5712e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.2901e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.3920e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.4495e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.6146e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 69%|██████▉   | 676/976 [03:07<01:18,  3.81it/s]

torch.Size([])
torch.Size([128])
tensor(1.4638e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2425e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4098e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9963e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0577e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.5608e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6819e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8299e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 69%|██████▉   | 677/976 [03:08<01:19,  3.75it/s]

torch.Size([])
torch.Size([128])
tensor(1.9824e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.6056e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9605e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.5125e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8652e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2550e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.7865e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5465e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 69%|██████▉   | 678/976 [03:08<01:18,  3.80it/s]

torch.Size([])
torch.Size([128])
tensor(3.9749e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.9340e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.0866e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.6246e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.6526e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.7189e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.2694e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.4138e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 70%|██████▉   | 679/976 [03:08<01:17,  3.82it/s]

torch.Size([])
torch.Size([128])
tensor(5.6438e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.3555e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.0605e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.4278e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.5902e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.9445e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.9927e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.0869e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 70%|██████▉   | 680/976 [03:08<01:18,  3.76it/s]

torch.Size([])
torch.Size([128])
tensor(1.6166e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6112e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1208e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9411e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3326e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1162e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7766e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1307e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 70%|██████▉   | 681/976 [03:09<01:17,  3.82it/s]

torch.Size([])
torch.Size([128])
tensor(1.8248e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7907e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7613e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9435e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2031e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2203e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2530e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7346e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 70%|██████▉   | 682/976 [03:09<01:18,  3.74it/s]

torch.Size([])
torch.Size([128])
tensor(3.1316e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.7989e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.1194e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.0366e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.0788e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.8060e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.3615e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2566e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 70%|██████▉   | 683/976 [03:09<01:19,  3.71it/s]

torch.Size([])
torch.Size([128])
tensor(5.0326e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.2598e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.4598e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.9124e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.7740e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.3334e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.8524e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.7837e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 70%|███████   | 684/976 [03:09<01:16,  3.84it/s]

torch.Size([])
torch.Size([128])
tensor(2.2253e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9543e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8552e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7736e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2804e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8754e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2766e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4490e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 70%|███████   | 685/976 [03:10<01:17,  3.74it/s]

torch.Size([])
torch.Size([128])
tensor(2.9265e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.6321e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.9082e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.8926e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1525e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.0514e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.4963e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.1515e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 70%|███████   | 686/976 [03:10<01:18,  3.69it/s]

torch.Size([])
torch.Size([128])
tensor(1.9458e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7630e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.0095e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9385e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8403e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8897e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1390e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4577e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 70%|███████   | 687/976 [03:10<01:18,  3.67it/s]

torch.Size([])
torch.Size([128])
tensor(3.8769e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.6935e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.4675e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4772e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.9191e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.7108e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.9020e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.0147e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 70%|███████   | 688/976 [03:11<01:15,  3.83it/s]

torch.Size([])
torch.Size([128])
tensor(1.0528e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0485e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1018e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2586e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1552e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.4420e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3569e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0709e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 71%|███████   | 689/976 [03:11<01:16,  3.76it/s]

torch.Size([])
torch.Size([128])
tensor(1.8698e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0730e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0102e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9938e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8543e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8170e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2210e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5196e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 71%|███████   | 690/976 [03:11<01:15,  3.77it/s]

torch.Size([])
torch.Size([128])
tensor(9.4292e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.7647e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.7546e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.1094e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.0144e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.9694e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.1646e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.1618e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 71%|███████   | 691/976 [03:11<01:15,  3.78it/s]

torch.Size([])
torch.Size([128])
tensor(5.7219e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.6946e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.8447e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.8302e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.6261e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.1380e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.8896e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.9218e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 71%|███████   | 692/976 [03:12<01:17,  3.66it/s]

torch.Size([])
torch.Size([128])
tensor(1.2556e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2648e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3726e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2239e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1524e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
 71%|███████   | 693/976 [03:12<01:17,  3.66it/s]
 71%|███████   | 694/976 [03:12<01:13,  3.82it/s]
 71%|███████   | 695/976 [03:12<01:15,  3.73it/s]
 71%|███████▏  | 696/976 [03:13<01:17,  3.62it/s]
 71%|███████▏  | 697/976 [03:13<01:18,  3.56it/s]
 72%|███████▏  | 698/976 [03:13<01:17,  3.58it/s]
 72%|███████▏  | 699/976 [03:14<01:13,  3.78it/s]
 72%|███████▏  | 700/976 [03:14<01:11,  3.88it/s]
 72%|███████▏  | 701/976 [03:14<01:09,  3.96it/s]
 72%|███████▏  | 702/976 [03:14<01:07,  4.07it/s]
 72%|███████▏  | 703/976 [03:14<01:04,  4.21it/s]
 72%|███████▏  | 704/976 [03:15<01:06,  4.07it/s]
 72%|███████▏  | 705/976 [03:15<01:04,  4.20it/s]
 72%|███████▏  | 706/976 [03:15<01:06,  4.07it/s]
 72%|███████▏  | 707/976 [03:15<01:05,  4.14it/s]
 73%|███████▎  | 708/976 [03:16<01:04,  4.17it/s]
 73%|███████▎  | 709/976 [03:16<01:07,  3.97it/s]
 73%|███████▎  | 710/976 [03:16<01:10,  3.78it/s]
 73%|███████▎  | 711/976 [03:17<01:11,  3.72it/s]
 73%|███████▎  | 712/976 [03:17<01:07,  3.92it/s]
 73%|███████▎  | 713/976 [03:17<01:09,  3.76it/s]
 73%|███████▎  | 714/976 [03:17<01:10,  3.69it/s]
 73%|███████▎  | 715/976 [03:18<01:09,  3.77it/s]
 73%|███████▎  | 716/976 [03:18<01:08,  3.78it/s]
 73%|███████▎  | 717/976 [03:18<01:08,  3.79it/s]
 74%|███████▎  | 718/976 [03:18<01:07,  3.83it/s]
 74%|███████▎  | 719/976 [03:19<01:09,  3.69it/s]
 74%|███████▍  | 720/976 [03:19<01:10,  3.62it/s]
 74%|███████▍  | 721/976 [03:19<01:10,  3.62it/s]
 74%|███████▍  | 722/976 [03:19<01:07,  3.78it/s]
 74%|███████▍  | 723/976 [03:20<01:07,  3.76it/s]
 74%|███████▍  | 724/976 [03:20<01:05,  3.83it/s]
 74%|███████▍  | 725/976 [03:20<01:06,  3.76it/s]
 74%|███████▍  | 726/976 [03:20<01:04,  3.88it/s]
tensor(3.7448e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 74%|███████▍  | 727/976 [03:21<01:04,  3.87it/s]

torch.Size([])
torch.Size([128])
tensor(3.5250e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.9693e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.9981e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.7239e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.8580e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.5740e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.0623e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.7075e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 75%|███████▍  | 728/976 [03:21<01:04,  3.87it/s]

torch.Size([])
torch.Size([128])
tensor(2.6133e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.8189e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.6158e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.4541e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.1722e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.2771e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.5480e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.5344e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 75%|███████▍  | 729/976 [03:21<01:04,  3.81it/s]

torch.Size([])
torch.Size([128])
tensor(2.6101e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7622e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7072e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9327e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4711e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8700e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6941e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4927e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 75%|███████▍  | 730/976 [03:22<01:04,  3.82it/s]

torch.Size([])
torch.Size([128])
tensor(3.6911e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.5261e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.0944e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.8914e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.3242e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.8804e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.0195e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.2533e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 75%|███████▍  | 731/976 [03:22<01:02,  3.91it/s]

torch.Size([])
torch.Size([128])
tensor(2.1187e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7574e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4543e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8708e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0926e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8475e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3331e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9121e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 75%|███████▌  | 732/976 [03:22<01:03,  3.83it/s]

torch.Size([])
torch.Size([128])
tensor(2.0124e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4135e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1456e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2397e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3092e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0418e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2754e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2746e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 75%|███████▌  | 733/976 [03:22<01:05,  3.72it/s]

torch.Size([])
torch.Size([128])
tensor(4.4007e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.9341e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.4236e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.9940e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.6906e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.9454e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.2750e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.2275e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 75%|███████▌  | 734/976 [03:23<01:07,  3.59it/s]

torch.Size([])
torch.Size([128])
tensor(6.0222e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.1066e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.0332e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.5019e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.8039e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.7328e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.8606e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.7730e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 75%|███████▌  | 735/976 [03:23<01:09,  3.49it/s]

torch.Size([])
torch.Size([128])
tensor(1.3211e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5661e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3493e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2142e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3046e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4222e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2460e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5220e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 75%|███████▌  | 736/976 [03:23<01:14,  3.23it/s]

torch.Size([])
torch.Size([128])
tensor(1.4463e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1932e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5331e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0462e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0870e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2241e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2738e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3121e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 76%|███████▌  | 737/976 [03:24<01:08,  3.47it/s]

torch.Size([])
torch.Size([128])
tensor(4.0519e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.3771e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4902e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.8264e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.4594e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2895e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.6721e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.4150e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 76%|███████▌  | 738/976 [03:24<01:07,  3.55it/s]

torch.Size([])
torch.Size([128])
tensor(4.3656e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.8021e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.6126e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.2772e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.1059e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.6265e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.2819e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.8472e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 76%|███████▌  | 739/976 [03:24<01:05,  3.64it/s]

torch.Size([])
torch.Size([128])
tensor(1.0853e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1275e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0051e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6142e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3218e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0355e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0790e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3480e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 76%|███████▌  | 740/976 [03:24<01:01,  3.84it/s]

torch.Size([])
torch.Size([128])
tensor(1.5510e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3945e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4830e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9923e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.8864e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3302e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4155e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5950e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 76%|███████▌  | 741/976 [03:25<01:01,  3.82it/s]

torch.Size([])
torch.Size([128])
tensor(3.0333e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.1827e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.6619e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.2219e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.5436e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.3136e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.2682e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.8774e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 76%|███████▌  | 742/976 [03:25<01:02,  3.77it/s]

torch.Size([])
torch.Size([128])
tensor(2.7421e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.3770e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.6767e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.6065e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.6612e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2149e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.8282e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4847e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 76%|███████▌  | 743/976 [03:25<00:59,  3.92it/s]

torch.Size([])
torch.Size([128])
tensor(1.5132e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3510e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7254e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4361e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6966e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2097e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6581e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4318e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 76%|███████▌  | 744/976 [03:25<01:01,  3.79it/s]

torch.Size([])
torch.Size([128])
tensor(1.1267e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6583e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1568e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6888e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6560e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3681e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0195e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6232e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 76%|███████▋  | 745/976 [03:26<01:02,  3.69it/s]

torch.Size([])
torch.Size([128])
tensor(3.1446e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.5820e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.7851e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.2755e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.8247e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3867e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.4647e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.6132e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 76%|███████▋  | 746/976 [03:26<01:03,  3.60it/s]

torch.Size([])
torch.Size([128])
tensor(1.7854e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7080e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5424e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2869e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6682e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0149e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5560e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5561e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 77%|███████▋  | 747/976 [03:26<01:05,  3.49it/s]

torch.Size([])
torch.Size([128])
tensor(1.3226e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2566e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7377e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1858e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4438e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4266e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1465e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2842e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 77%|███████▋  | 748/976 [03:26<01:02,  3.62it/s]

torch.Size([])
torch.Size([128])
tensor(2.0952e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3685e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9674e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8820e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0449e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4879e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5128e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6945e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 77%|███████▋  | 749/976 [03:27<00:59,  3.81it/s]

torch.Size([])
torch.Size([128])
tensor(4.9205e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.5611e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.1526e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.0458e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.1957e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.6029e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.9856e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.9045e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 77%|███████▋  | 750/976 [03:27<00:58,  3.87it/s]

torch.Size([])
torch.Size([128])
tensor(2.1705e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.6351e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3127e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3345e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1948e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.8607e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1043e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3389e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 77%|███████▋  | 751/976 [03:27<00:57,  3.94it/s]

torch.Size([])
torch.Size([128])
tensor(1.9125e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.4452e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9418e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3652e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0294e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0463e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.8118e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.6033e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 77%|███████▋  | 752/976 [03:27<00:55,  4.06it/s]

torch.Size([])
torch.Size([128])
tensor(1.5467e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1619e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6120e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5453e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9794e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7246e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5580e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7257e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 77%|███████▋  | 753/976 [03:28<00:55,  3.99it/s]

torch.Size([])
torch.Size([128])
tensor(4.7095e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.8028e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.3963e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.8164e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)

 77%|███████▋  | 754/976 [03:28<00:55,  4.02it/s]

 77%|███████▋  | 755/976 [03:28<00:57,  3.88it/s]

 77%|███████▋  | 756/976 [03:29<01:05,  3.38it/s]

 78%|███████▊  | 757/976 [03:29<01:03,  3.48it/s]

 78%|███████▊  | 758/976 [03:29<00:59,  3.68it/s]

 78%|███████▊  | 759/976 [03:29<01:01,  3.52it/s]

 78%|███████▊  | 760/976 [03:30<00:59,  3.64it/s]

 78%|███████▊  | 761/976 [03:30<00:59,  3.62it/s]

 78%|███████▊  | 762/976 [03:30<01:02,  3.43it/s]

 78%|███████▊  | 763/976 [03:31<01:00,  3.53it/s]

 78%|███████▊  | 764/976 [03:31<00:59,  3.55it/s]

 78%|███████▊  | 765/976 [03:31<00:58,  3.61it/s]

 78%|███████▊  | 766/976 [03:31<00:57,  3.64it/s]

 79%|███████▊  | 767/976 [03:32<00:55,  3.74it/s]

 79%|███████▊  | 768/976 [03:32<00:55,  3.76it/s]

 79%|███████▉  | 769/976 [03:32<00:54,  3.79it/s]

 79%|███████▉  | 770/976 [03:32<00:56,  3.62it/s]

 79%|███████▉  | 771/976 [03:33<00:57,  3.54it/s]

 79%|███████▉  | 772/976 [03:33<00:59,  3.40it/s]

 79%|███████▉  | 773/976 [03:33<00:59,  3.42it/s]

 79%|███████▉  | 774/976 [03:34<00:59,  3.41it/s]

 79%|███████▉  | 775/976 [03:34<00:59,  3.38it/s]

 80%|███████▉  | 776/976 [03:34<00:58,  3.41it/s]

 80%|███████▉  | 777/976 [03:34<00:57,  3.47it/s]

 80%|███████▉  | 778/976 [03:35<00:56,  3.51it/s]

 80%|███████▉  | 779/976 [03:35<00:54,  3.59it/s]

 80%|███████▉  | 780/976 [03:35<00:54,  3.61it/s]

 80%|████████  | 781/976 [03:36<00:54,  3.60it/s]

 80%|████████  | 782/976 [03:36<00:52,  3.68it/s]

 80%|████████  | 783/976 [03:36<00:52,  3.68it/s]

 80%|████████  | 784/976 [03:36<00:49,  3.86it/s]

 80%|████████  | 785/976 [03:37<00:51,  3.72it/s]

 81%|████████  | 786/976 [03:37<00:51,  3.66it/s]

 81%|████████  | 787/976 [03:37<00:52,  3.63it/s]

torch.Size([])
torch.Size([128])
tensor(9.6507e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1659e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 81%|████████  | 788/976 [03:37<00:52,  3.60it/s]

torch.Size([])
torch.Size([128])
tensor(2.2061e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.9270e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1525e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1201e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3624e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0627e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4921e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8788e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 81%|████████  | 789/976 [03:38<00:49,  3.76it/s]

torch.Size([])
torch.Size([128])
tensor(1.9733e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9068e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1064e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3335e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5033e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7604e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3892e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2991e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 81%|████████  | 790/976 [03:38<00:48,  3.82it/s]

torch.Size([])
torch.Size([128])
tensor(1.1721e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4862e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3560e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2421e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2602e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4034e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3981e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1390e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 81%|████████  | 791/976 [03:38<00:49,  3.76it/s]

torch.Size([])
torch.Size([128])
tensor(1.9339e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6728e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9139e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4529e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7540e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1916e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5574e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2501e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 81%|████████  | 792/976 [03:39<00:49,  3.68it/s]

torch.Size([])
torch.Size([128])
tensor(3.4428e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4542e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.5326e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.5028e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.9670e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.8944e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3084e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.3968e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 81%|████████▏ | 793/976 [03:39<00:51,  3.59it/s]

torch.Size([])
torch.Size([128])
tensor(1.9126e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0786e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9280e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9991e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4055e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6633e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8835e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9425e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 81%|████████▏ | 794/976 [03:39<00:52,  3.49it/s]

torch.Size([])
torch.Size([128])
tensor(7.1595e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.8946e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.7981e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.4920e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.7440e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.0462e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.2250e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.2905e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 81%|████████▏ | 795/976 [03:39<00:51,  3.49it/s]

torch.Size([])
torch.Size([128])
tensor(1.0522e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2817e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5829e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3119e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1709e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1756e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3858e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1131e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 82%|████████▏ | 796/976 [03:40<00:52,  3.45it/s]

torch.Size([])
torch.Size([128])
tensor(2.6983e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4956e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.8836e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.3331e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.2040e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3809e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2803e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.2877e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 82%|████████▏ | 797/976 [03:40<00:51,  3.46it/s]

torch.Size([])
torch.Size([128])
tensor(1.6058e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3850e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6495e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2458e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2071e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5123e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4819e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8226e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 82%|████████▏ | 798/976 [03:40<00:50,  3.52it/s]

torch.Size([])
torch.Size([128])
tensor(9.4790e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.9487e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0746e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.9565e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.1558e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0119e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.8049e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.5501e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 82%|████████▏ | 799/976 [03:41<00:47,  3.71it/s]

torch.Size([])
torch.Size([128])
tensor(1.0460e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3413e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2887e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5180e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2277e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.7803e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4746e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1510e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 82%|████████▏ | 800/976 [03:41<00:46,  3.82it/s]

torch.Size([])
torch.Size([128])
tensor(5.8520e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.4864e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.0252e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.0133e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.9357e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.8572e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.8464e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.7666e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 82%|████████▏ | 801/976 [03:41<00:44,  3.94it/s]

torch.Size([])
torch.Size([128])
tensor(1.0320e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4049e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5656e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5548e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3244e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4469e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1670e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7010e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 82%|████████▏ | 802/976 [03:41<00:43,  3.99it/s]

torch.Size([])
torch.Size([128])
tensor(9.5455e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1265e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.3288e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.6954e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.9878e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.4316e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.4620e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1619e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 82%|████████▏ | 803/976 [03:41<00:43,  3.99it/s]

torch.Size([])
torch.Size([128])
tensor(4.7857e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.3028e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.8913e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.2033e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.4962e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4949e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.1179e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.1099e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 82%|████████▏ | 804/976 [03:42<00:42,  4.04it/s]

torch.Size([])
torch.Size([128])
tensor(9.8342e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.8367e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.8609e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.6314e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.7187e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.2625e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.1373e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.0079e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 82%|████████▏ | 805/976 [03:42<00:43,  3.91it/s]

torch.Size([])
torch.Size([128])
tensor(1.2946e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3035e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3220e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4786e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2472e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4223e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3977e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2412e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 83%|████████▎ | 806/976 [03:42<00:44,  3.84it/s]

torch.Size([])
torch.Size([128])
tensor(1.3166e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7068e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5894e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3434e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7229e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3982e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4745e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3761e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 83%|████████▎ | 807/976 [03:43<00:42,  3.93it/s]

torch.Size([])
torch.Size([128])
tensor(1.9192e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5230e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0602e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6590e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8898e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7225e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8969e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6723e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 83%|████████▎ | 808/976 [03:43<00:42,  3.95it/s]

torch.Size([])
torch.Size([128])
tensor(2.3125e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.9226e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4649e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2303e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2622e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.6179e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3962e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.7212e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 83%|████████▎ | 809/976 [03:43<00:43,  3.88it/s]

torch.Size([])
torch.Size([128])
tensor(1.0311e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0375e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1105e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0304e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0155e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2800e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.5213e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.2231e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 83%|████████▎ | 810/976 [03:43<00:45,  3.65it/s]

torch.Size([])
torch.Size([128])
tensor(1.6261e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2017e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3248e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9303e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4717e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5656e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2223e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8647e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 83%|████████▎ | 811/976 [03:44<00:48,  3.42it/s]

torch.Size([])
torch.Size([128])
tensor(3.8647e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.9446e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.2373e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.7471e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.4539e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.8142e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.2028e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.2648e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 83%|████████▎ | 812/976 [03:44<01:00,  2.73it/s]

torch.Size([])
torch.Size([128])
tensor(2.5049e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7018e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.1431e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0545e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2326e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.7727e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4349e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9319e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 83%|████████▎ | 813/976 [03:45<01:07,  2.43it/s]

torch.Size([])
torch.Size([128])
tensor(7.9955e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0504e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.4346e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.9010e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.1723e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.2682e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.8658e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.0503e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 83%|████████▎ | 814/976 [03:45<01:11,  2.27it/s]

torch.Size([])
torch.Size([128])
tensor(8.7422e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.9906e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0247e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.0944e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5369e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0071e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 87%|████████▋ | 849/976 [03:55<00:33,  3.76it/s]

torch.Size([])
torch.Size([128])
tensor(8.7603e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.3600e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.8816e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.0632e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.9134e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.8242e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.0519e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.2364e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 87%|████████▋ | 850/976 [03:56<00:32,  3.86it/s]

torch.Size([])
torch.Size([128])
tensor(1.4118e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0516e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1173e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.9756e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0082e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.3855e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0301e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3174e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 87%|████████▋ | 851/976 [03:56<00:31,  3.93it/s]

torch.Size([])
torch.Size([128])
tensor(4.6188e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.9180e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.8514e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.6833e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.3525e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.5333e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.0853e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.0694e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 87%|████████▋ | 852/976 [03:56<00:30,  4.02it/s]

torch.Size([])
torch.Size([128])
tensor(1.5267e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2594e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5023e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5102e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7374e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3454e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2492e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4914e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 87%|████████▋ | 853/976 [03:56<00:30,  4.01it/s]

torch.Size([])
torch.Size([128])
tensor(1.0828e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4359e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3776e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1880e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7773e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0633e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0355e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1378e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 88%|████████▊ | 854/976 [03:57<00:30,  4.06it/s]

torch.Size([])
torch.Size([128])
tensor(1.4057e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4745e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1655e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0061e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3166e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.8954e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3352e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2013e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 88%|████████▊ | 855/976 [03:57<00:31,  3.84it/s]

torch.Size([])
torch.Size([128])
tensor(4.7989e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.9493e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.8233e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.9761e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.4273e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.3619e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.8924e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.6930e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 88%|████████▊ | 856/976 [03:57<00:31,  3.76it/s]

torch.Size([])
torch.Size([128])
tensor(9.0174e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.8823e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2870e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0010e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.5985e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.6269e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0574e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0624e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 88%|████████▊ | 857/976 [03:57<00:31,  3.75it/s]

torch.Size([])
torch.Size([128])
tensor(7.3358e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.5312e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.4434e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.6905e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.0829e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.2576e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.9312e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.2773e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 88%|████████▊ | 858/976 [03:58<00:32,  3.68it/s]

torch.Size([])
torch.Size([128])
tensor(3.3134e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.8099e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.7947e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9169e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3875e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.9554e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.0375e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6833e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 88%|████████▊ | 859/976 [03:58<00:32,  3.60it/s]

torch.Size([])
torch.Size([128])
tensor(2.0572e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9928e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.7076e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3537e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4844e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1624e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9708e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4039e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 88%|████████▊ | 860/976 [03:58<00:32,  3.54it/s]

torch.Size([])
torch.Size([128])
tensor(9.2116e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0987e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4035e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0043e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2859e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.4534e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1135e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0733e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 88%|████████▊ | 861/976 [03:59<00:33,  3.48it/s]

torch.Size([])
torch.Size([128])
tensor(1.1735e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1439e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0382e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0316e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2983e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.4147e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0225e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1951e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 88%|████████▊ | 862/976 [03:59<00:32,  3.46it/s]

torch.Size([])
torch.Size([128])
tensor(1.2891e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7574e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7354e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5785e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7872e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4937e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4511e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5460e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 88%|████████▊ | 863/976 [03:59<00:33,  3.42it/s]

torch.Size([])
torch.Size([128])
tensor(2.5346e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0457e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8927e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0734e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4642e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8021e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3995e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9873e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 89%|████████▊ | 864/976 [03:59<00:32,  3.43it/s]

torch.Size([])
torch.Size([128])
tensor(8.3028e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.9045e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2704e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6206e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4282e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1529e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0398e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0700e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 89%|████████▊ | 865/976 [04:00<00:32,  3.40it/s]

torch.Size([])
torch.Size([128])
tensor(8.6172e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.2082e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.3194e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1051e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.7010e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.9720e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.3653e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0331e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 89%|████████▊ | 866/976 [04:00<00:32,  3.43it/s]

torch.Size([])
torch.Size([128])
tensor(3.9803e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.6183e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.0376e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.4534e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.2900e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.0428e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.0636e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3328e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 89%|████████▉ | 867/976 [04:00<00:31,  3.41it/s]

torch.Size([])
torch.Size([128])
tensor(1.4174e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4753e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2212e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4654e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3976e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5289e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2997e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3100e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 89%|████████▉ | 868/976 [04:01<00:30,  3.51it/s]

torch.Size([])
torch.Size([128])
tensor(2.0145e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4259e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4320e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3253e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2024e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3432e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0815e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.8681e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 89%|████████▉ | 869/976 [04:01<00:30,  3.55it/s]

torch.Size([])
torch.Size([128])
tensor(1.2872e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5070e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3329e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.9419e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.2201e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1720e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5075e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4412e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 89%|████████▉ | 870/976 [04:01<00:30,  3.53it/s]

torch.Size([])
torch.Size([128])
tensor(2.6670e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.5849e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4753e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.8808e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.2051e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.6110e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.9219e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.6385e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 89%|████████▉ | 871/976 [04:01<00:29,  3.52it/s]

torch.Size([])
torch.Size([128])
tensor(2.0541e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5793e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8028e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4307e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6066e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9740e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6098e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9067e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 89%|████████▉ | 872/976 [04:02<00:30,  3.45it/s]

torch.Size([])
torch.Size([128])
tensor(7.5466e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.5109e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.4265e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.8422e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.2111e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.4153e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.5660e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.3972e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 89%|████████▉ | 873/976 [04:02<00:29,  3.54it/s]

torch.Size([])
torch.Size([128])
tensor(6.9340e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.8929e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.1788e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.5681e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.9596e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.8765e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.8826e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.7763e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 90%|████████▉ | 874/976 [04:02<00:27,  3.67it/s]

torch.Size([])
torch.Size([128])
tensor(2.2936e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1878e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3817e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1589e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8708e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3892e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3270e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3372e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 90%|████████▉ | 875/976 [04:03<00:28,  3.61it/s]

torch.Size([])
torch.Size([128])
tensor(1.3322e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.9388e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.5414e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.7451e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.2091e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.5380e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 93%|█████████▎| 910/976 [04:12<00:18,  3.60it/s]

torch.Size([])
torch.Size([128])
tensor(2.0819e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6241e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5481e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5747e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0504e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9749e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7575e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8748e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 93%|█████████▎| 911/976 [04:12<00:18,  3.59it/s]

torch.Size([])
torch.Size([128])
tensor(4.1965e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.9872e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.1657e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.9587e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.7673e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.2161e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.9118e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.2602e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 93%|█████████▎| 912/976 [04:12<00:16,  3.79it/s]

torch.Size([])
torch.Size([128])
tensor(6.2913e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.4989e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.6802e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.3772e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.7274e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.0188e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.1980e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.4944e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 94%|█████████▎| 913/976 [04:13<00:17,  3.70it/s]

torch.Size([])
torch.Size([128])
tensor(2.0701e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7327e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0206e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8903e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4065e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7212e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5264e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9301e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 94%|█████████▎| 914/976 [04:13<00:17,  3.63it/s]

torch.Size([])
torch.Size([128])
tensor(9.3696e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.8096e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.9765e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.9627e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0726e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.7086e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.8952e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.3371e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 94%|█████████▍| 915/976 [04:13<00:16,  3.63it/s]

torch.Size([])
torch.Size([128])
tensor(4.8072e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.6586e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.1787e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.7327e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.1629e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.7298e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.3419e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.0794e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 94%|█████████▍| 916/976 [04:14<00:16,  3.66it/s]

torch.Size([])
torch.Size([128])
tensor(5.5326e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.0948e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.4986e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.6986e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.8652e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.3045e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.8629e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.1574e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 94%|█████████▍| 917/976 [04:14<00:16,  3.67it/s]

torch.Size([])
torch.Size([128])
tensor(3.2430e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0616e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.7913e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.5211e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.9085e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0880e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.1251e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3893e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 94%|█████████▍| 918/976 [04:14<00:15,  3.78it/s]

torch.Size([])
torch.Size([128])
tensor(7.9612e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.3695e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0240e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.8676e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.8998e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0193e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.5247e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.3651e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 94%|█████████▍| 919/976 [04:14<00:15,  3.77it/s]

torch.Size([])
torch.Size([128])
tensor(4.9845e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.2028e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.7735e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.7705e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.9736e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.8962e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.2917e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.0250e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 94%|█████████▍| 920/976 [04:15<00:14,  3.76it/s]

torch.Size([])
torch.Size([128])
tensor(3.8352e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.1260e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.5840e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.2917e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.1627e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.4940e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.6385e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.7027e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 94%|█████████▍| 921/976 [04:15<00:14,  3.91it/s]

torch.Size([])
torch.Size([128])
tensor(4.5301e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.4672e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.4981e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.6629e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.7833e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.4900e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.9991e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.7246e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 94%|█████████▍| 922/976 [04:15<00:14,  3.80it/s]

torch.Size([])
torch.Size([128])
tensor(4.8011e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.8995e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.3316e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.3515e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.6074e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.7948e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.6555e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.6585e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 95%|█████████▍| 923/976 [04:15<00:14,  3.75it/s]

torch.Size([])
torch.Size([128])
tensor(4.4788e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.3347e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.5268e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.7806e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.0658e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.0052e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.7470e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.3197e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 95%|█████████▍| 924/976 [04:16<00:13,  3.79it/s]

torch.Size([])
torch.Size([128])
tensor(4.6987e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.1649e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.3460e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.5350e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.4866e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.3269e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.6733e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.5815e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 95%|█████████▍| 925/976 [04:16<00:13,  3.77it/s]

torch.Size([])
torch.Size([128])
tensor(2.4132e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.7841e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0306e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2251e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.5902e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4746e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7773e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.6467e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 95%|█████████▍| 926/976 [04:16<00:13,  3.73it/s]

torch.Size([])
torch.Size([128])
tensor(3.5935e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.2497e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.6849e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.5536e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.9790e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.2490e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.7320e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.8767e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 95%|█████████▍| 927/976 [04:16<00:13,  3.77it/s]

torch.Size([])
torch.Size([128])
tensor(4.2785e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.6842e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.8268e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.2578e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.5546e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.8891e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.3483e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.1451e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 95%|█████████▌| 928/976 [04:17<00:12,  3.80it/s]

torch.Size([])
torch.Size([128])
tensor(1.0431e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2099e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4850e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0665e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2751e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2803e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6840e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.0637e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 95%|█████████▌| 929/976 [04:17<00:13,  3.60it/s]

torch.Size([])
torch.Size([128])
tensor(3.8811e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.1907e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.5127e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.0139e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.7440e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.1052e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.9549e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.7619e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 95%|█████████▌| 930/976 [04:17<00:12,  3.56it/s]

torch.Size([])
torch.Size([128])
tensor(4.5965e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.3317e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.3785e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.7560e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.9703e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.6132e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.5123e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.7840e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 95%|█████████▌| 931/976 [04:18<00:12,  3.64it/s]

torch.Size([])
torch.Size([128])
tensor(5.3389e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.2314e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.7816e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.1657e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.2575e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.2640e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.1781e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.3588e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 95%|█████████▌| 932/976 [04:18<00:11,  3.67it/s]

torch.Size([])
torch.Size([128])
tensor(1.0633e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0190e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2289e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.8447e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.8132e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.9536e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.2326e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1014e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 96%|█████████▌| 933/976 [04:18<00:11,  3.64it/s]

torch.Size([])
torch.Size([128])
tensor(2.4746e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.7802e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.8574e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9970e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2656e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.0344e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1389e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.8155e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 96%|█████████▌| 934/976 [04:18<00:11,  3.58it/s]

torch.Size([])
torch.Size([128])
tensor(4.2200e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.3852e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.0963e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.5101e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.0460e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.0006e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.2777e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.4647e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 96%|█████████▌| 935/976 [04:19<00:11,  3.47it/s]

torch.Size([])
torch.Size([128])
tensor(8.9722e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1461e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1629e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1183e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0603e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.0815e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2885e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1021e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 96%|█████████▌| 936/976 [04:19<00:11,  3.48it/s]

torch.Size([])
torch.Size([128])
tensor(9.1164e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2477e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0374e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.6736e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1646e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1540e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.4025e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.9590e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 96%|█████████▌| 937/976 [04:19<00:10,  3.55it/s]

torch.Size([])
torch.Size([128])
tensor(1.7467e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.0719e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9810e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8528e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3330e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7167e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1156e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9089e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 96%|█████████▌| 938/976 [04:20<00:10,  3.66it/s]

torch.Size([])
torch.Size([128])
tensor(3.3517e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.2435e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.6630e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.8645e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.1820e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.8358e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.7405e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.0029e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 96%|█████████▌| 939/976 [04:20<00:10,  3.68it/s]

torch.Size([])
torch.Size([128])
tensor(7.6719e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.9912e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.0878e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.2883e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.4405e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.0455e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.8265e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.2740e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 96%|█████████▋| 940/976 [04:20<00:09,  3.65it/s]

torch.Size([])
torch.Size([128])
tensor(1.3821e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4986e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6007e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3411e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6824e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6276e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.5478e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6158e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 96%|█████████▋| 941/976 [04:20<00:09,  3.65it/s]

torch.Size([])
torch.Size([128])
tensor(6.0982e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.1877e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.9268e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.1025e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.3028e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.3785e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.6869e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.0257e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 97%|█████████▋| 942/976 [04:21<00:09,  3.61it/s]

torch.Size([])
torch.Size([128])
tensor(4.1772e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.2180e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.0868e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.3748e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.8513e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.6966e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.3855e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.3447e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 97%|█████████▋| 943/976 [04:21<00:09,  3.55it/s]

torch.Size([])
torch.Size([128])
tensor(4.7402e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.4555e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.6388e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.8571e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.3794e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.9857e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.2421e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.3705e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 97%|█████████▋| 944/976 [04:21<00:09,  3.45it/s]

torch.Size([])
torch.Size([128])
tensor(8.8569e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.9848e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0078e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.7094e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0459e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.0871e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.4457e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.6627e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 97%|█████████▋| 945/976 [04:22<00:08,  3.47it/s]

torch.Size([])
torch.Size([128])
tensor(1.2482e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7445e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2972e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6823e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8076e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7195e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5172e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0742e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 97%|█████████▋| 946/976 [04:22<00:08,  3.45it/s]

torch.Size([])
torch.Size([128])
tensor(3.5200e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.2776e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.6092e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.1084e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.6990e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4762e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.5122e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.6694e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 97%|█████████▋| 947/976 [04:22<00:08,  3.45it/s]

torch.Size([])
torch.Size([128])
tensor(9.4220e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.2269e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.8049e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.7514e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.2205e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.7683e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.2006e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.0067e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 97%|█████████▋| 948/976 [04:22<00:07,  3.52it/s]

torch.Size([])
torch.Size([128])
tensor(1.7466e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3854e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4578e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2809e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6816e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4518e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3458e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.3445e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 97%|█████████▋| 949/976 [04:23<00:07,  3.74it/s]

torch.Size([])
torch.Size([128])
tensor(1.5254e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.5046e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.0470e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0943e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0985e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4528e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.1310e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0341e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 97%|█████████▋| 950/976 [04:23<00:06,  3.81it/s]

torch.Size([])
torch.Size([128])
tensor(6.8253e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.2904e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.6490e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.6090e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.7452e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.3170e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.2051e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.2567e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 97%|█████████▋| 951/976 [04:23<00:06,  3.94it/s]

torch.Size([])
torch.Size([128])
tensor(4.6684e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.8974e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.4026e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.2595e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.8066e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.7568e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.9048e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.7251e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 98%|█████████▊| 952/976 [04:23<00:05,  4.06it/s]

torch.Size([])
torch.Size([128])
tensor(1.4335e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4233e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.1025e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8534e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9034e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.7359e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4421e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6456e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 98%|█████████▊| 953/976 [04:24<00:05,  4.15it/s]

torch.Size([])
torch.Size([128])
tensor(7.3754e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.1981e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.8425e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.7878e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.2662e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.1458e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.3193e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.0250e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 98%|█████████▊| 954/976 [04:24<00:05,  4.17it/s]

torch.Size([])
torch.Size([128])
tensor(3.2191e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.3855e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.6229e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.8409e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.1156e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.9161e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.5095e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.7632e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 98%|█████████▊| 955/976 [04:24<00:05,  4.20it/s]

torch.Size([])
torch.Size([128])
tensor(8.5039e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.5316e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(9.3468e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.1939e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.8650e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.4047e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.5893e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.7931e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 98%|█████████▊| 956/976 [04:24<00:04,  4.01it/s]

torch.Size([])
torch.Size([128])
tensor(1.7521e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.5827e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2717e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0399e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.4294e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6960e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.0654e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.2182e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 98%|█████████▊| 957/976 [04:25<00:04,  3.90it/s]

torch.Size([])
torch.Size([128])
tensor(8.0701e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.7902e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.3209e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.7053e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.0839e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(8.4008e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.0319e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.6990e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 98%|█████████▊| 958/976 [04:25<00:04,  4.04it/s]

torch.Size([])
torch.Size([128])
tensor(3.5733e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.4852e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.6114e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.4617e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.1641e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.6137e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.1685e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.0596e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 98%|█████████▊| 959/976 [04:25<00:04,  3.97it/s]

torch.Size([])
torch.Size([128])
tensor(3.7193e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.2179e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.6475e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.9864e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.8166e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.4032e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.9130e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.8683e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 98%|█████████▊| 960/976 [04:25<00:04,  3.79it/s]

torch.Size([])
torch.Size([128])
tensor(2.3431e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.7952e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.7943e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.8267e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2150e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.5696e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.9877e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.9367e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 98%|█████████▊| 961/976 [04:26<00:04,  3.70it/s]

torch.Size([])
torch.Size([128])
tensor(7.1996e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.6751e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.8702e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.3307e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(6.9237e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.9466e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.4441e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.3064e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 99%|█████████▊| 962/976 [04:26<00:03,  3.63it/s]

torch.Size([])
torch.Size([128])
tensor(3.1374e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.1451e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.0189e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.5367e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.3273e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.5712e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.7861e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.9994e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 99%|█████████▊| 963/976 [04:26<00:03,  3.52it/s]

torch.Size([])
torch.Size([128])
tensor(3.1375e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.8071e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.3559e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.1006e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.8699e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.2127e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.6261e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.7177e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 99%|█████████▉| 964/976 [04:27<00:03,  3.47it/s]

torch.Size([])
torch.Size([128])
tensor(2.3583e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9270e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6834e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.8573e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(2.2338e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.6979e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9594e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(1.9577e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 99%|█████████▉| 965/976 [04:27<00:03,  3.47it/s]

torch.Size([])
torch.Size([128])
tensor(3.4868e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.6095e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.3235e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.4209e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(3.7758e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.5012e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(4.1346e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.1542e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])

 99%|█████████▉| 966/976 [04:27<00:02,  3.54it/s]

torch.Size([])
torch.Size([128])
tensor(7.4513e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(7.2573e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)
torch.Size([])
torch.Size([128])
tensor(5.3710e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)

<details>
<summary>Question - if you've done this correctly (and logged everything), clipped surrogate objective will be close to zero. Does this mean that this term is not the most important in the objective function?</summary>

No, this doesn't necessarily mean that the term is unimportant.

Clipped surrogate objective is a moving target. At each rollout phase, we generate new experiences, and the expected value of the clipped surrogate objective will be zero (because the expected value of advantages is zero). But this doesn't mean that differentiating clipped surrogate objective wrt the policy doesn't have a large gradient! It's the gradient of the objective function that matters, not the value.

As we make update steps in the learning phase, the policy values $\pi(a_t \mid s_t)$ will increase for actions which have positive advantages, and decrease for actions which have negative advantages, so the clipped surrogate objective will no longer be zero in expectation. But (thanks to the fact that we're clipping changes larger than $\epsilon$) it will still be very small.

</details>
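
The distinction between the value of the objective and its gradient can be checked numerically. The sketch below is plain Python (no PyTorch) and uses made-up data: the advantages are random draws normalised to mean zero, and the names `clipped_surrogate`, `old_logprobs` and `new_logprobs` are illustrative rather than taken from the exercise code. It evaluates the objective at the start of a learning phase, when the probability ratio is exactly 1 for every sample, and estimates the gradient with respect to each sample's log-probability by finite differences.

```python
import math
import random


def clipped_surrogate(new_logprobs, old_logprobs, advantages, clip_coef=0.2):
    """Batch mean of min(r * A, clip(r, 1 - eps, 1 + eps) * A)."""
    total = 0.0
    for new_lp, old_lp, adv in zip(new_logprobs, old_logprobs, advantages):
        ratio = math.exp(new_lp - old_lp)
        clipped_ratio = min(max(ratio, 1 - clip_coef), 1 + clip_coef)
        total += min(ratio * adv, clipped_ratio * adv)
    return total / len(advantages)


random.seed(0)
n = 128

# Made-up advantages, normalised to mean 0 and std 1
raw = [random.gauss(0, 1) for _ in range(n)]
mean = sum(raw) / n
std = math.sqrt(sum((a - mean) ** 2 for a in raw) / n)
advantages = [(a - mean) / std for a in raw]

# Start of the learning phase: the policy has not changed yet, so ratio = 1
old_logprobs = [math.log(0.5)] * n
new_logprobs = list(old_logprobs)

value = clipped_surrogate(new_logprobs, old_logprobs, advantages)

# Central finite-difference gradient w.r.t. each sample's log-probability
h = 1e-4
grads = []
for i in range(n):
    up = list(new_logprobs)
    down = list(new_logprobs)
    up[i] += h
    down[i] -= h
    diff = clipped_surrogate(up, old_logprobs, advantages) - clipped_surrogate(down, old_logprobs, advantages)
    grads.append(diff / (2 * h))

print(f"objective value: {value:.2e}")
print(f"largest gradient magnitude: {max(abs(g) for g in grads):.2e}")
```

At this point the objective is zero up to floating-point error, while each gradient entry is approximately $A_i / N$ (the sample's advantage divided by the batch size), and it is this gradient that drives the policy update.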

## Reward Shaping

Yesterday during DQN, we covered **catastrophic forgetting** - this is the phenomenon whereby the replay memory mostly contains successful experiences, and the model forgets how to adapt or recover from bad states. In fact, you might find it even more severe here than for DQN, because PPO is an on-policy method (we generate a new batch of experiences for each learning phase) unlike DQN. Here's an example reward trajectory from a PPO run on CartPole, using the solution code:

<img src="https://raw.githubusercontent.com/info-arena/ARENA_img/main/misc/cf2.png" width="600">

<details>
<summary>A small tangent: on-policy vs off-policy algorithms</summary>

In RL, an algorithm being **on-policy** means we learn only from the most recent trajectory of experiences, and **off-policy** involves learning from a sampled trajectory of past experiences. 

It's common to describe **PPO as on-policy** (because it generates an entirely new batch of experiences each learning phase) and **DQN as off-policy** (because its replay buffer effectively acts as a memory bank for past experiences). However, it's important to remember that the line between these two can become blurry. As soon as PPO takes a single learning step it technically ceases to be on-policy because it's now learning using data generated from a slightly different version of the current policy. However, the fact that PPO uses clipping to explicitly keep its current and target policies close together is another reason why it's usually fine to refer to it as an on-policy method, unlike DQN.

</details>


We can fix catastrophic forgetting in the same way as we did yesterday (by having our replay memory keep some fraction of bad experiences from previous phases), but here we'll introduce another option - **reward shaping**. 

The rewards for `CartPole` encourage the agent to keep the episode running for as long as possible (which it then needs to associate with balancing the pole), but we can write a wrapper around the `CartPoleEnv` to modify the dynamics of the environment, and help the agent learn faster.

Try to modify the reward to make the task as easy to learn as possible. Compare this against your performance on the original environment, and see if the agent learns faster with your shaped reward. If you can bound the reward on each timestep between 0 and 1, this will make comparing the results to `CartPole-v1` easier.

<details>
<summary>Help - I'm not sure what I'm meant to return in this function.</summary>

The tuple `(obs, reward, terminated, truncated, info)` is returned from the CartPole environment. Here, `reward` is always 1 unless the episode has terminated.

You should change this, so that `reward` incentivises good behaviour, even if the pole hasn't fallen yet. You can use the information returned in `obs` to construct a new reward function.

</details>

<details>
<summary>Help - I'm confused about how to choose a reward function. (Try and think about this for a while before looking at this dropdown.)</summary>

Right now, the agent always gets a reward of 1 for each timestep it is active. You should try and change this so that it gets a reward between 0 and 1, which is closer to 1 when the agent is performing well / behaving stably, and equals 0 when the agent is doing very poorly.

The variables we have available to us are cart position, cart velocity, pole angle, and pole angular velocity, which I'll denote as $x$, $v$, $\theta$ and $\omega$.

Here are a few suggestions which you can try out:
* $r = 1 - (\theta / \theta_{\text{max}})^2$. This will have the effect of keeping the angle close to zero.
* $r = 1 - (x / x_{\text{max}})^2$. This will have the effect of pushing it back towards the centre of the screen (i.e. it won't tip and fall to the side of the screen).

You could also try using e.g. $|\theta / \theta_{\text{max}}|$ rather than $(\theta / \theta_{\text{max}})^2$. This would still mean reward is in the range (0, 1), but it would result in a larger penalty for very small deviations from the vertical position.

You can also try a linear combination of two or more of these rewards!
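
To make the difference between the quadratic and absolute-value penalties concrete, here is a small plain-Python comparison. The function names are illustrative, and the constants `0.2095` and `2.4` are the angle and position limits used elsewhere in this section; this is a sketch for building intuition, not a recommended reward.

```python
THETA_MAX = 0.2095  # pole angle limit (radians) used in this section
X_MAX = 2.4         # cart position limit used in this section


def quadratic_reward(value, max_value):
    """1 - (value / max_value)^2: nearly flat close to zero, steep near the limit."""
    return 1 - (value / max_value) ** 2


def absolute_reward(value, max_value):
    """1 - |value / max_value|: the same slope everywhere."""
    return 1 - abs(value / max_value)


for theta in (0.01, 0.05, 0.1, 0.2):
    quad = quadratic_reward(theta, THETA_MAX)
    absolute = absolute_reward(theta, THETA_MAX)
    print(f"theta={theta:.2f}  quadratic={quad:.3f}  absolute={absolute:.3f}")
```

For small angles the absolute-value reward is already noticeably below 1 while the quadratic one is still almost exactly 1, which is the "larger penalty for very small deviations" mentioned above. The same two functions can be applied to `x` with `X_MAX`.
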
</details>


<details>
<summary>Help - my agent's episodic return is smaller than it was in the original CartPole environment.</summary>

This is to be expected, because your reward function is no longer always 1 when the agent is upright. Both your time-discounted reward estimates and your actual realised rewards will be less than they were in the original CartPole environment.

For a fairer test, measure the length of your episodes - hopefully your agent learns how to stay upright for the entire 500 timestep interval as fast as or faster than it did previously.
</details>

Note - if you want to use the maximum possible values of `x` and `theta` in your reward function (to keep it bounded between 0 and 1) then you can. These values can be found at the [documentation page](https://github.com/Farama-Foundation/Gymnasium/blob/v0.29.0/gymnasium/envs/classic_control/cartpole.py#L51) (note - the table contains the max possible values, not max unterminated values - those are below the table). You can also use `self.x_threshold` and `self.theta_threshold_radians` to get these values directly (again, see the source code for how these are calculated).
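
For reference, the two termination thresholds can be recomputed by hand. The sketch below is plain Python and reflects my reading of the linked source file (12 degrees converted to radians for the pole, and 2.4 for the cart position); treat the exact expressions as an assumption and check them against the version of `gymnasium` you have installed.

```python
import math

# Pole angle threshold: 12 degrees expressed in radians (assumed from the linked source)
theta_threshold_radians = 12 * 2 * math.pi / 360
# Cart position threshold, in the environment's position units (assumed from the linked source)
x_threshold = 2.4

print(f"theta threshold: {theta_threshold_radians:.4f} rad")
print(f"x threshold: {x_threshold}")
```

Dividing `theta` and `x` by these values gives quantities whose magnitude stays at or below 1 for all unterminated states, which is what keeps a shaped reward like `1 - abs(theta / theta_threshold_radians)` inside the range 0 to 1.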

### Exercise - implement reward shaping

> ```yaml
> Difficulty: 🔴🔴🔴⚪⚪
> Importance: 🔵🔵🔵⚪⚪
> 
> You should spend up to 15-30 minutes on this exercise.
> ```

See [this link](https://api.wandb.ai/links/callum-mcdougall/p7e739rp) for what an ideal wandb run here should look like (using the reward function in the solutions).

In [None]:
from gymnasium.envs.classic_control import CartPoleEnv


class EasyCart(CartPoleEnv):
    def step(self, action):
        obs, reward, terminated, truncated, info = super().step(action)

        # raise NotImplementedError()
        x, v, theta, omega = obs

        rew_1 = 1 - abs(theta/0.2095)
        rew_2 = 1 - abs(x/2.4)

        reward_new = (rew_1 + rew_2)/2

        return obs, reward_new, terminated, truncated, info


gym.envs.registration.register(id="EasyCart-v0", entry_point=EasyCart, max_episode_steps=500)
args = PPOArgs(env_id="EasyCart-v0", use_wandb=True, video_log_freq=50)
trainer = PPOTrainer(args)
trainer.train()

<details>
<summary>Solution (one possible implementation)</summary>

I tried out a few different simple reward functions here. One of the best ones I found used a mix of absolute value penalties for both the angle and the horizontal position (this outperformed using absolute value penalty for just one of these two). My guess as to why this is the case - penalising by horizontal position helps the agent improve its long-term strategy, and penalising by angle helps the agent improve its short-term strategy, so both combined work better than either on their own.

```python
class EasyCart(CartPoleEnv):
    def step(self, action):
        obs, rew, terminated, truncated, info = super().step(action)
        
        x, v, theta, omega = obs

        # First reward: angle should be close to zero
        rew_1 = 1 - abs(theta / 0.2095)
        # Second reward: position should be close to the center
        rew_2 = 1 - abs(x / 2.4)

        # Combine both rewards (keep it in the [0, 1] range)
        rew_new = (rew_1 + rew_2) / 2

        return obs, rew_new, terminated, truncated, info
```

The result:

<img src="https://raw.githubusercontent.com/info-arena/ARENA_img/main/misc/best-episode-length.png" width="600">

To illustrate the point about different forms of reward optimizing different kinds of behaviour - below are links to three videos generated during the WandB training, one of just position penalisation, one of just angle penalisation, and one of both. Can you guess which is which?

* [First video](https://wandb.ai//callum-mcdougall/PPOCart/reports/videos-23-07-07-13-48-08---Vmlldzo0ODI1NDcw?accessToken=uajtb4w1gaqkbrf2utonbg2b93lfdlw9eaet4qd9n6zuegkb3mif7l3sbuke8l4j)
* [Second video](https://wandb.ai//callum-mcdougall/PPOCart/reports/videos-23-07-07-13-47-22---Vmlldzo0ODI1NDY2?accessToken=qoss34zyuaso1b5s40nehamsk7nj93ijopmscesde6mjote0i194e7l99sg2k6dg)
* [Third video](https://wandb.ai//callum-mcdougall/PPOCart/reports/videos-23-07-07-13-45-15---Vmlldzo0ODI1NDQ4?accessToken=n1btft5zfqx0aqk8wkuh13xtp5mn19q5ga0mpjmvjnn2nq8q62xz4hsomd0vnots)

<details>
<summary>Answer</summary>

* First video = angle penalisation
* Second video = both (from the same video as the loss curve above)
* Third video = position penalisation

</details>

</details>

<br>

Now, change the environment such that the reward incentivises the agent to spin very fast. You may change the termination conditions of the environment (i.e. return a different value for `terminated`) if you think this will help.

See [this link](https://api.wandb.ai/links/callum-mcdougall/86y2vtsk) for what an ideal wandb run here should look like (using the reward function in the solutions).

In [93]:
class SpinCart(CartPoleEnv):
    def step(self, action):
        obs, reward, terminated, truncated, info = super().step(action)

        raise NotImplementedError()

        return (obs, reward_new, terminated, truncated, info)


gym.envs.registration.register(id="SpinCart-v0", entry_point=SpinCart, max_episode_steps=500)
args = PPOArgs(env_id="SpinCart-v0", use_wandb=True, video_log_freq=50)
trainer = PPOTrainer(args)
trainer.train()


<details>
<summary>Solution (one possible implementation)</summary>

```python
class SpinCart(CartPoleEnv):
    def step(self, action):
        # gymnasium's step returns a 5-tuple (terminated and truncated are separate flags)
        obs, reward, terminated, truncated, info = super().step(action)

        x, v, theta, omega = obs

        # Allow for 360-degree rotation (but keep the cart on-screen)
        terminated = abs(x) > self.x_threshold

        # Reward function incentivises fast spinning while staying still & near centre
        rotation_speed_reward = min(1, 0.1 * abs(omega))
        stability_penalty = max(1, abs(x / 2.5) + abs(v / 10))
        reward_new = rotation_speed_reward - 0.5 * stability_penalty

        return obs, reward_new, terminated, truncated, info
```

</details>

Another thing you can try is "dancing". It's up to you to define what qualifies as "dancing" - work out a sensible definition, and the reward function to incentivise it.

# 4️⃣ Atari

> ##### Learning Objectives
>
> - Understand how PPO can be used in visual domains, with appropriate architectures (CNNs)
> - Understand the idea of policy and value heads
> - Train an agent to solve the Breakout environment

## Introduction

In this section, you'll extend your PPO implementation to play Atari games.

The `gymnasium` library supports a variety of different Atari games - you can find them [here](https://ale.farama.org/environments/) (if you get a message when you click on this link asking whether you want to switch to gymnasium, ignore this and proceed to the gym site). You can try whichever ones you want, but we recommend you stick with the easier environments like Pong, Breakout, and Space Invaders.

These environments are very different from CartPole. Rather than having observations of shape `(4,)` (representing a vector of `(x, v, theta, omega)`), the raw observations are now images of shape `(210, 160, 3)`, representing pixels in the game screen. This leads to a variety of additional challenges relative to the CartPole environment, for example:

* We need a much larger network, because finding the optimal strategy isn't as simple as solving a basic differential equation
* Reward shaping is much more difficult, because our observations are low-level and don't contain easily-accessible information about the high-level abstractions in the game (finding these abstractions in the first place is part of the model's challenge!)

The action space is also different for each environment. For example, in Breakout, the environment has 4 actions - run the code below to see this (if you get an error, try restarting the kernel and running everything again, minus the library installs).

In [97]:
env = gym.make("ALE/Breakout-v5", render_mode="rgb_array")

print(env.action_space)  # Discrete(4): 4 actions to choose from
print(env.observation_space)  # Box(0, 255, (210, 160, 3), uint8): an RGB image of the game screen

Discrete(4)
Box(0, 255, (210, 160, 3), uint8)


These 4 actions are "do nothing", "fire the ball", "move right", and "move left" respectively, which you can see from:

In [98]:
print(env.get_action_meanings())

['NOOP', 'FIRE', 'RIGHT', 'LEFT']


You can see more details on the game-specific [documentation page](https://ale.farama.org/environments/breakout/). On this documentation page, you can also see information like the reward for this environment. In this case, the reward comes from breaking bricks in the wall (more reward from breaking bricks higher up). This is a more challenging reward function than the one for CartPole, where a very simple strategy (move in the direction you're tipping) leads directly to a higher reward by marginally prolonging episode length.

We can also run the code below to take some random steps in our environment and animate the results:

In [99]:
def display_frames(frames: Int[Arr, "timesteps height width channels"], figsize=(4, 5)):
    fig, ax = plt.subplots(figsize=figsize)
    im = ax.imshow(frames[0])
    plt.close()

    def update(frame):
        im.set_array(frame)
        return [im]

    ani = FuncAnimation(fig, update, frames=frames, interval=100)
    display(HTML(ani.to_jshtml()))


nsteps = 150

frames = []
obs, info = env.reset()
for _ in tqdm(range(nsteps)):
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)
    frames.append(obs)

display_frames(np.stack(frames))



### Playing Breakout

Just like for CartPole and MountainCar, we've given you a Python file to play Atari games yourself. The file is called `play_breakout.py`, and running it (i.e. `python play_breakout.py`) will open up a window for you to play the game. Take note of the key instructions, which will be printed in your terminal.

You should also be able to try out other games, by changing the relevant parts of the `play_breakout.py` file to match those games' [documentation pages](https://ale.farama.org/environments/complete_list/).

## Implementational details of Atari

The [37 Implementational Details of PPO](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/#:~:text=Atari%2Dspecific%20implementation%20details) post describes how to get PPO working for games like Atari. In the sections below, we'll go through these steps.

### Wrappers (details [#1-7](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/#:~:text=The%20Use%20of%20NoopResetEnv), and [#9](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/#:~:text=Scaling%20the%20Images%20to%20Range%20%5B0%2C%201%5D))

All the extra details except for one are just wrappers on the environment, which implement specific behaviours. For example:

* **Frame Skipping** - we repeat the agent's action for a number of frames (by default 4), and sum the reward over these frames. This saves time when the model's forward pass is computationally more expensive than an environment step, because we only run one forward pass per group of frames.
* **Image Transformations** - we resize the image from `(210, 160)` to `(L, L)` for some smaller value `L` (in this case we'll use 84), and convert it to grayscale.
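
To illustrate the frame-skipping idea in isolation, here is a minimal plain-Python sketch. Both classes are hypothetical: `CountingEnv` is a toy stand-in for a real environment, and `FrameSkip` is not the wrapper used in `part3_ppo/utils.py`. The only thing it assumes is a `step` method returning the 5-tuple `(obs, reward, terminated, truncated, info)`.

```python
class CountingEnv:
    """Toy stand-in for an environment: reward 1 per step, ends after `horizon` steps."""

    def __init__(self, horizon=10):
        self.horizon = horizon
        self.t = 0

    def reset(self):
        self.t = 0
        return self.t, {}

    def step(self, action):
        self.t += 1
        terminated = self.t >= self.horizon
        return self.t, 1.0, terminated, False, {}


class FrameSkip:
    """Repeats each action for `skip` frames and sums the rewards over those frames."""

    def __init__(self, env, skip=4):
        self.env = env
        self.skip = skip

    def reset(self):
        return self.env.reset()

    def step(self, action):
        total_reward = 0.0
        for _ in range(self.skip):
            obs, reward, terminated, truncated, info = self.env.step(action)
            total_reward += reward
            # Stop repeating the action as soon as the episode ends
            if terminated or truncated:
                break
        return obs, total_reward, terminated, truncated, info


env = FrameSkip(CountingEnv(horizon=10), skip=4)
env.reset()
for _ in range(3):
    print(env.step(0))
```

The agent chooses an action once per call to `FrameSkip.step`, so the number of policy forward passes per episode drops by roughly a factor of `skip`.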

We've written some environment wrappers for you (and imported some others from the `gymnasium` library), combining them all together into the `prepare_atari_env` function in the `part3_ppo/utils.py` file. You can have a read of this and see how it works, but since we're implementing these for you, you won't have to worry about them too much.

The code below visualizes the results of these wrappers (with the frames stacked across rows, so we can see them all at once). You might want to have a think about the kind of information your actor & critic networks are getting here, and how this might make the RL task easier.

In [100]:
env_wrapped = prepare_atari_env(env)

frames = []
obs, info = env_wrapped.reset()
for _ in tqdm(range(nsteps)):
    action = env_wrapped.action_space.sample()
    obs, reward, terminated, truncated, info = env_wrapped.step(action)
    obs = einops.repeat(np.array(obs), "frames h w -> h (frames w) 3")  # stack frames across the row
    frames.append(obs)

display_frames(np.stack(frames), figsize=(12, 3))



### Shared CNN for actor & critic ([detail #8](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/#:~:text=Shared%20Nature%2DCNN%20network))

This is the most interesting one conceptually. If we have a new observation space then it naturally follows that we need a new architecture, and if we're working with images then using a convolutional neural network is reasonable. But another particularly interesting feature here is that we use a **shared architecture** for the actor and critic networks. The idea behind this is that the early layers of our model perform **feature extraction** (i.e. they find the high-level abstractions contained in the image), and then the actor and critic components turn these features into actions / value estimates respectively. This is commonly referred to as having a **policy head** and a **value head**. We'll see this idea come up later, when we perform RL on transformers.

### Exercise - rewrite `get_actor_and_critic`

> ```yaml
> Difficulty: 🔴🔴⚪⚪⚪
> Importance: 🔵🔵🔵⚪⚪
> 
> You should spend up to 10-15 minutes on this exercise.
> ```

The function `get_actor_and_critic` had a `mode` argument, which we ignored previously, but which we'll now return to. When `mode` is not `"atari"` the function should behave exactly as it did before (i.e. the CartPole version), but when `mode == "atari"` it should return a shared CNN architecture for the actor and critic. The architecture should be as follows (you can open it in a new tab if it's hard to see clearly):

<img src="https://raw.githubusercontent.com/info-arena/ARENA_img/main/misc/mermaid-diagram-2024-11-28-154133.svg" style="background-color: #fffbe0; padding: 10px;" width="350" height="1100">

<!-- 
flowchart TD
    A["Input<br>shape = (4, L, L)"] -> B["8x8 Conv<br>32 Channels<br>Padding 0<br>Stride 4"]
    B -> C["ReLU"]
    C -> D["4x4 Conv<br>64 Channels<br>Padding 0<br>Stride 2"]
    D -> E["ReLU"]
    E -> F["3x3 Conv<br>64 Channels<br>Padding 0<br>Stride 1"]
    F -> G[Flatten]
    G -> H[ReLU]
    H -> I["Linear<br>512 outputs"]
    I -> J["ReLU"]
    J -> K1["Linear(512, n_act)<br>row_norm=0.01"]
    K1 -> L1["Actor output"]
    J -> K2["Linear(512, 1)<br>row_norm=1"]
    K2 -> L2["Critic output"]

{
  "theme": "default",
  "themeVariables": {
    "fontSize": "22px"
    }
}
-->


Note - when calculating the number of input features for the linear layer, you can assume that the value `L` is 4 modulo 8, i.e. we can write `L = 8m + 4` for some integer `m`. This will make the convolutions easier to track. You shouldn't hardcode the number of input features assuming an input shape of `(4, 84, 84)`; this is bad practice!

We leave the exercise of finding the number of input features to the linear layer as a challenge for you. If you're stuck, you can find a hint in the section below (this isn't a particularly conceptually important detail).

<details>
<summary>Help - I don't know what the number of inputs for the first linear layer should be.</summary>

You can test this empirically by just doing a forward pass through the first half of the network and seeing what the shape of the output is.

Alternatively, you can use the convolution formula. There's never any padding, so for a conv with parameters `(size, stride)`, the dimensions change as `L -> 1 + (L - size) // stride` (see the [documentation page](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html)). So we have:

```
8m+4 -> 1 + (8m-4)//4 = 2m
2m   -> 1 + (2m-4)//2 = m-1
m-1  -> 1 + (m-4)//1  = m-3
```

For instance, if `L = 84` then `m = 10` and `L_new = m-3 = 7`. So the output of the final convolution has shape `(64, 7, 7)`, and after flattening the linear layer is fed `64 * 7 * 7 = 3136` features.
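
The same arithmetic can be checked with a few lines of plain Python. The helper names below are illustrative (they are not part of the exercise code), and the kernel sizes and strides are the ones from the diagram.

```python
def conv_out_size(size_in, kernel, stride, padding=0):
    """Output side length of a square convolution (dilation 1)."""
    return 1 + (size_in + 2 * padding - kernel) // stride


def atari_linear_in_features(L, out_channels=64):
    """Flattened feature count after the three convolutions in the diagram."""
    for kernel, stride in [(8, 4), (4, 2), (3, 1)]:
        L = conv_out_size(L, kernel, stride)
    return out_channels * L * L


print(atari_linear_in_features(84))  # 3136
```

For any `L = 8m + 4` this agrees with the closed form `64 * (m - 3) ** 2`, so either version can be used to size the linear layer.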

</details>

Now, you can fill in the `get_actor_and_critic_atari` function below, which is called when we call `get_actor_and_critic` with `mode == "atari"`.

Note that we take the observation shape as argument, not the number of observations. It should be `(4, L, L)` as indicated by the diagram. The shape `(4, L, L)` is a reflection of the fact that we're using 4 frames of history per input (which helps the model calculate things like velocity), and each of these frames is a monochrome resized square image.

In [103]:
def get_actor_and_critic_atari(obs_shape: tuple[int,], num_actions: int) -> tuple[nn.Sequential, nn.Sequential]:
    """
    Returns (actor, critic) in the "atari" case, according to diagram above.
    """
    assert obs_shape[-1] % 8 == 4

    # raise NotImplementedError()
    L_after_conv = (obs_shape[-1] // 8) - 3
    in_feat_Linear = 64 * L_after_conv * L_after_conv

    shared = nn.Sequential(
        layer_init(nn.Conv2d(4, 32, 8, 4, 0)),
        nn.ReLU(),
        layer_init(nn.Conv2d(32, 64, 4, 2, 0)),
        nn.ReLU(),
        layer_init(nn.Conv2d(64, 64, 3, 1, 0)),
        nn.Flatten(),
        nn.ReLU(),
        layer_init(nn.Linear(in_feat_Linear, 512)),
        nn.ReLU(),  # final ReLU from the diagram, applied before the actor / critic heads
    )
    actor = nn.Sequential(shared, layer_init(nn.Linear(512, num_actions), std = 0.01))
    critic = nn.Sequential(shared, layer_init(nn.Linear(512, 1), std = 1))
    return actor, critic

tests.test_get_actor_and_critic(get_actor_and_critic, mode="atari")

All tests in `test_get_actor_and_critic(mode='atari')` passed!


<details><summary>Solution</summary>

```python
def get_actor_and_critic_atari(obs_shape: tuple[int,], num_actions: int) -> tuple[nn.Sequential, nn.Sequential]:
    """
    Returns (actor, critic) in the "atari" case, according to diagram above.
    """
    assert obs_shape[-1] % 8 == 4

    L_after_convolutions = (obs_shape[-1] // 8) - 3
    in_features = 64 * L_after_convolutions * L_after_convolutions

    hidden = nn.Sequential(
        layer_init(nn.Conv2d(4, 32, 8, stride=4, padding=0)),
        nn.ReLU(),
        layer_init(nn.Conv2d(32, 64, 4, stride=2, padding=0)),
        nn.ReLU(),
        layer_init(nn.Conv2d(64, 64, 3, stride=1, padding=0)),
        nn.ReLU(),
        nn.Flatten(),
        layer_init(nn.Linear(in_features, 512)),
        nn.ReLU(),
    )

    actor = nn.Sequential(hidden, layer_init(nn.Linear(512, num_actions), std=0.01))
    critic = nn.Sequential(hidden, layer_init(nn.Linear(512, 1), std=1))

    return actor, critic
```
</details>

## Training Atari

Now, you should be able to run an Atari training loop!

We recommend you use the following parameters, for fidelity:

In [None]:
args = PPOArgs(
    env_id="ALE/Breakout-v5",
    wandb_project_name="PPOAtari",
    use_wandb=True,
    mode="atari",
    clip_coef=0.1,
    num_envs=8,
    video_log_freq=25,
)
trainer = PPOTrainer(args)
trainer.train()


Note that this will probably take a lot longer to train than your previous experiments, because the architecture is much larger, and finding an initial strategy is much harder than it was for CartPole. Don't worry if it starts off with pretty bad performance (on my machine the code above takes about 40 minutes to run, and I only start seeing any improvement after about the 5-10 minute mark, or approx 70k total agent steps). You can always experiment with different methods to try and boost performance early on, like an entropy bonus which is initially large and then decays (analogous to our epsilon scheduling in DQN, which reduced the probability of exploration over time).
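One way to picture the decaying entropy bonus is a linear schedule on the entropy coefficient. This is a hypothetical sketch: the function `linear_ent_coef` and its default start and end values are invented for illustration, and are not part of `PPOArgs` or the course code.

```python
def linear_ent_coef(step: int, total_steps: int, start: float = 0.05, end: float = 0.01) -> float:
    """Linearly interpolates the entropy coefficient from `start` to `end` over training."""
    frac = min(max(step / total_steps, 0.0), 1.0)  # fraction of training completed, clamped to [0, 1]
    return start + frac * (end - start)


for step in (0, 250_000, 500_000):
    print(step, linear_ent_coef(step, total_steps=500_000))
```

You could then use the returned value in place of a fixed `ent_coef` each time the entropy bonus is computed.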

Here is a video produced from a successful run, using the parameters above:

<video width="320" height="480" controls>
<source src="https://raw.githubusercontent.com/info-arena/ARENA_img/main/misc/media-23/2304.mp4" type="video/mp4">
</video>

and here's the corresponding plot of episodic returns (with episodic lengths following a similar pattern):

<img src="https://raw.githubusercontent.com/info-arena/ARENA_img/main/misc/wandb-atari-returns.png" width="550">

### A note on debugging crashed kernels

> *This section is more relevant if you're doing these exercises on VSCode; you can skip it if you're in Colab.*

Because the `gymnasium` library is a bit fragile, sometimes you can get uninformative kernel errors like this:

<img src="https://raw.githubusercontent.com/info-arena/ARENA_img/main/misc/kernel_error.png" width="600">

which annoyingly doesn't tell you much about the nature or location of the error. When this happens, it's often good practice to replace your code with lower-level code bit by bit, until the error message starts being informative.

For instance, you might start with `trainer.train()`, and if this fails without an informative error message then you might try replacing this function call with the actual contents of the `train` function (which should involve the methods `trainer.rollout_phase()` and `trainer.learning_phase()`). If the problem is in `rollout_phase`, you can again replace this line with the actual contents of this method.

If you're working in `.py` files rather than `.ipynb`, here's a useful tip: `Shift + Enter` normally runs the cell your cursor is in, but if you have text highlighted (and you've turned on `Send Selection To Interactive Window` in VSCode settings) then `Shift + Enter` will run just the code you've highlighted. This could be a single variable name, a single line, or a single block of code.

# 5️⃣ Mujoco

> ##### Learning Objectives
>
> - Understand how PPO can be used to train agents in continuous action spaces
> - Install and interact with the MuJoCo physics engine
> - Train an agent to solve the Hopper environment

> An important note - **MuJoCo environments are notoriously demanding when it comes to having exactly the right library installs and versions.** For one thing, they require Python version `<3.11`, which means they currently **won't work in Colab** until a fix for this problem is found. To get them working, you'll need to create a new virtual environment with Python version `3.10`. We recommend creating a Linux-based environment via the instructions in the [Streamlit homepage](https://arena-chapter2-rl.streamlit.app/#how-to-access-the-course), swapping `python=3.11` for `python=3.10` when you go through the "Workspace setup instructions" code.

## Installation & Rendering

Once you've gone through the step described above (a new virtual env with Python version `3.10`), you should re-run all the imports up to this point in the file, as well as the following code, which will install all the MuJoCo-specific packages and dependencies:

In [None]:
%pip install mujoco free-mujoco-py

!sudo apt-get install -y libgl1-mesa-dev libgl1-mesa-glx libglew-dev libosmesa6-dev software-properties-common
!sudo apt-get install -y patchelf

To test that this works, run the following. The first time you run this, it might take about 1-2 minutes, and throw up several warnings and messages. But the cell should still run without raising an exception, and all subsequent times you run it, it should be a lot faster (with no error messages).

In [None]:
env = gym.make("Hopper-v4", render_mode="rgb_array")

print(env.action_space)
print(env.observation_space)

Previously, we've dealt with discrete action spaces (e.g. going right or left in CartPole). But here, we have a continuous action space - the actions take the form of a vector of 3 values, each in the range `[-1.0, 1.0]`.

<details>
<summary>Question - after reading the <a href="https://gymnasium.farama.org/environments/mujoco/hopper/">documentation page</a>, can you see exactly what our 3 actions mean?</summary>

They represent the **torque** applied between the three different links of the hopper. There is:

* The **thigh rotor** (i.e. connecting the upper and middle parts of the leg),
* The **leg rotor** (i.e. connecting the middle and lower parts of the leg),
* The **foot rotor** (i.e. connecting the lower part of the leg to the foot).

<img src="https://raw.githubusercontent.com/info-arena/ARENA_img/main/misc/hopper-torque.png" width="400">

</details>

How do we deal with a continuous action space, when it comes to choosing actions? Rather than our actor network's output being a vector of `logits` which we turn into a probability distribution via `Categorical(logits=logits)`, we instead have our actor output two vectors `mu` and `log_sigma`, which we turn into a normal distribution which is then sampled from.
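To make this concrete, here is a minimal sketch of turning `mu` and `log_sigma` into a normal distribution and sampling actions from it. The tensors below are random stand-ins for the actor's outputs (an assumption for illustration), and the batch size of 5 is arbitrary.

```python
import torch as t

t.manual_seed(0)
batch_size, num_actions = 5, 3  # Hopper has 3 action components

# Stand-ins for the actor's outputs (random / zero here, learned in practice)
mu = t.randn(batch_size, num_actions)
log_sigma = t.zeros(1, num_actions)  # one log standard deviation per action component

# exp maps any real number to a strictly positive standard deviation
sigma = log_sigma.exp().broadcast_to(mu.shape)
dist = t.distributions.Normal(mu, sigma)

actions = dist.sample()
print(actions.shape)  # torch.Size([5, 3])
```

Each row of `actions` is one sampled action vector, with one entry per torque.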

The observations take the form of a vector of 11 values describing the positions and velocities of the hopper's body parts and joints. So unlike for Atari, we can't directly visualize the environment using its observations. Instead, we'll visualize it using `env.render()`, which returns an array representing the environment state (thanks to the fact that we initialized the env with `render_mode="rgb_array"`).

In [None]:
nsteps = 150

frames = []
obs, info = env.reset()
for _ in tqdm(range(nsteps)):
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)
    frames.append(env.render())  # frames can't come from obs, because unlike in Atari our observations aren't images

display_frames(np.stack(frames))

## Implementational details of MuJoCo

### Clipping, Scaling & Normalization ([details #5-9](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/#:~:text=Handling%20of%20action%20clipping%20to%20valid%20range%20and%20storage))

Just like for Atari, there are a few messy implementational details which will be taken care of with gym wrappers. For example, if we generate our actions by sampling from a normal distribution, then there's some non-zero chance that our action will be outside of the allowed action space. We deal with this by clipping the actions to be within the allowed range (in this case between -1 and 1).

See the function `prepare_mujoco_env` within `part3_ppo/utils` (and read details 5-9 on the PPO page) for more information.
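The wrappers handle clipping for you, but the idea itself is simple. The following is only a minimal NumPy illustration of it (not the wrapper's actual code), using unit-normal samples as stand-ins for actions sampled from the policy.

```python
import numpy as np

rng = np.random.default_rng(0)
low, high = -1.0, 1.0  # Hopper's action bounds

# Samples from a unit normal will sometimes land outside [-1, 1]
raw_actions = rng.normal(loc=0.0, scale=1.0, size=(1000, 3))
clipped_actions = np.clip(raw_actions, low, high)

print((np.abs(raw_actions) > 1).mean())  # fraction of components that were out of range
print(clipped_actions.min(), clipped_actions.max())
```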

### Actor and Critic networks ([details #1-4](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/#:~:text=Continuous%20actions%20via%20normal%20distributions))

Our actor and critic networks are quite similar to the ones we used for cartpole. They won't have shared architecture.

<details>
<summary>Question - can you see why it's less useful to have shared architecture in this case, relative to the case of Atari?</summary>

The point of the shared architecture in Atari was that it allowed our critic and actor to perform **feature extraction**, i.e. the early part of the network (which was fed the raw pixel input) generated a high-level representation of the state, which was then fed into the actor and critic heads. But for CartPole and for MuJoCo, we have a very small observation space (4 values in the case of CartPole, 11 for the Hopper in MuJoCo), so there's no feature extraction necessary.

</details>

The only difference will be in the actor network. There will be an `actor_mu` and `actor_log_sigma` network. The `actor_mu` will have exactly the same architecture as the CartPole actor network, and it will output a vector used as the mean of our normal distribution. The `actor_log_sigma` network will just be a bias, since the standard deviation is **state-independent** ([detail #2](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/#:~:text=State%2Dindependent%20log%20standard%20deviation)).

Because of this extra complexity, we'll create a class for our actor and critic networks.

### Exercise - implement `Actor` and `Critic`

> ```yaml
> Difficulty: 🔴🔴⚪⚪⚪
> Importance: 🔵🔵🔵⚪⚪
> 
> You should spend up to 10-15 minutes on this exercise.
> ```

As discussed, the architecture of `actor_mu` is identical to your CartPole actor network, and the critic is identical to your CartPole critic. The only difference is the addition of `actor_log_sigma`, which you should initialize as an `nn.Parameter` object of shape `(1, num_actions)`.

Your `Actor` class's forward function should return a tuple of `(mu, sigma, dist)`, where `mu` and `sigma` are the parameters of the normal distribution, and `dist` was created from these values using `torch.distributions.Normal`.

<details>
<summary>Why do we use <code>log_sigma</code> rather than just outputting <code>sigma</code> ?</summary>

We have our network output `log_sigma` rather than `sigma` because the standard deviation must always be positive, whereas a learned parameter can take any real value. If we learn the log standard deviation and exponentiate it, the result is guaranteed to be positive, so we can treat `log_sigma` just like a regular learned weight.
</details>

Tip - when creating your distribution, you can use the `broadcast_to` tensor method, so that your standard deviation and mean are the same shape.

We've given you the function `get_actor_and_critic_mujoco` (which is called when you call `get_actor_and_critic` with `mode="mujoco"`). All you need to do is fill in the `Actor` and `Critic` classes.

In [None]:
class Critic(nn.Module):
    def __init__(self, num_obs):
        super().__init__()
        # raise NotImplementedError()
        self.critic = nn.Sequential(
            layer_init(nn.Linear(num_obs, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 1), std = 1.0)
        )

    def forward(self, obs) -> Tensor:
        # raise NotImplementedError()
        return self.critic(obs)
        


class Actor(nn.Module):
    actor_mu: nn.Sequential
    actor_log_sigma: nn.Parameter

    def __init__(self, num_obs, num_actions):
        super().__init__()
        # raise NotImplementedError()
        self.actor_mu = nn.Sequential(
            layer_init(nn.Linear(num_obs, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, num_actions), std = 0.01)
        )
        # Log of the standard deviation: one learned value per action component, shared across all states
        self.actor_log_sigma = nn.Parameter(t.zeros(1, num_actions))

    def forward(self, obs) -> tuple[Tensor, Tensor, t.distributions.Normal]:
        # raise NotImplementedError()
        mu = self.actor_mu(obs)
        # actor_log_sigma is a parameter (not a module), so we use it directly rather than calling it on obs
        sigma = t.exp(self.actor_log_sigma).broadcast_to(mu.shape)
        dist = t.distributions.Normal(mu, sigma)
        return mu, sigma, dist


def get_actor_and_critic_mujoco(num_obs: int, num_actions: int):
    """
    Returns (actor, critic) in the "mujoco" case, according to description above.
    """
    return Actor(num_obs, num_actions), Critic(num_obs)


tests.test_get_actor_and_critic(get_actor_and_critic, mode="mujoco")



<details><summary>Solution</summary>

```python
class Critic(nn.Module):
    def __init__(self, num_obs):
        super().__init__()
        self.critic = nn.Sequential(
            layer_init(nn.Linear(num_obs, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 1), std=1.0),
        )

    def forward(self, obs) -> Tensor:
        value = self.critic(obs)
        return value


class Actor(nn.Module):
    actor_mu: nn.Sequential
    actor_log_sigma: nn.Parameter

    def __init__(self, num_obs, num_actions):
        super().__init__()
        self.actor_mu = nn.Sequential(
            layer_init(nn.Linear(num_obs, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, num_actions), std=0.01),
        )
        self.actor_log_sigma = nn.Parameter(t.zeros(1, num_actions))

    def forward(self, obs) -> tuple[Tensor, Tensor, t.distributions.Normal]:
        mu = self.actor_mu(obs)
        sigma = t.exp(self.actor_log_sigma).broadcast_to(mu.shape)
        dist = t.distributions.Normal(mu, sigma)
        return mu, sigma, dist
```
</details>

### Exercise - additional rewrites

> ```yaml
> Difficulty: 🔴🔴🔴⚪⚪
> Importance: 🔵🔵⚪⚪⚪
> 
> You should spend up to 10-25 minutes on this exercise.
> ```

There are a few more rewrites you'll need for continuous action spaces, which is why we recommend that you create a new solutions file for this part (like we've done with `solutions.py` and `solutions_cts.py`).

You'll need to make the following changes:

#### Logprobs and entropy

Rather than `probs = Categorical(logits=logits)` as your distribution (which you sample from & pass into your loss functions), you'll just use `dist` as your distribution. Methods like `.log_prob(action)` and `.entropy()` will work on `dist` just like they did on `probs`.

Note that these two methods will return objects of shape `(batch_size, action_shape)` (e.g. for Hopper the last dimension will be 3). We treat the action components as independent ([detail #4](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/#:~:text=Independent%20action%20components)), meaning **we take a product of the probabilities, so we sum the logprobs / entropies**. For example:

$$
\begin{aligned}
\operatorname{prob}\left(a_t\right)&=\operatorname{prob}\left(a_t^1\right) \cdot \operatorname{prob}\left(a_t^2\right) \\
\log \operatorname{prob}\left(a_t\right)&=\log \operatorname{prob}\left(a_t^1\right) + \log \operatorname{prob}\left(a_t^2\right)
\end{aligned}
$$

So you'll need to sum logprobs and entropy over the last dimension. The logprobs value that you add to the replay memory should be summed over (because you don't need the individual logprobs, you only need the logprob of the action as a whole).
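Here is a small sketch of what this summation looks like in practice, using made-up tensors of shape `(4, 3)` as stand-ins for the actor's outputs. As a cross-check (not something the exercises require), it compares the result to `torch.distributions.Independent`, which treats the last dimension as a single joint event and so performs the same sum internally.

```python
import torch as t

t.manual_seed(0)
mu = t.randn(4, 3)  # (batch_size, num_actions)
sigma = t.ones(4, 3)
dist = t.distributions.Normal(mu, sigma)
actions = dist.sample()

per_component = dist.log_prob(actions)  # shape (4, 3): one logprob per action component
joint = per_component.sum(-1)  # shape (4,): logprob of each action vector as a whole

# Independent reinterprets the last batch dim as an event dim, summing logprobs / entropies over it
joint_dist = t.distributions.Independent(dist, 1)
assert t.allclose(joint, joint_dist.log_prob(actions))
assert t.allclose(dist.entropy().sum(-1), joint_dist.entropy())

print(per_component.shape, joint.shape)
```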

#### Logging

You should log `mu` and `sigma` during the learning phase.

Below, we've given you a template for all the things you'll need to change (with new class & function names so they don't overwrite the previous versions). If you prefer, you can instead rewrite your previous classes & functions in the way indicated by the code below.

In [None]:
class PPOAgentCts(PPOAgent):
    def play_step(self) -> list[dict]:
        """
        Changes required:
            - actor returns (mu, sigma, dist), with dist used to sample actions
            - logprobs need to be summed over action space
        """
        obs = self.next_obs
        terminated = self.next_terminated

        with t.inference_mode():
            # CHANGED: actor returns (mu, sigma, dist), with dist used to sample actions
            mu, sigma, dist = self.actor.forward(obs)
        actions = dist.sample()

        next_obs, rewards, next_terminated, next_truncated, infos = self.envs.step(actions.cpu().numpy())

        # CHANGED: logprobs need to be summed over action space
        logprobs = dist.log_prob(actions).sum(-1).cpu().numpy()
        with t.inference_mode():
            values = self.critic(obs).flatten().cpu().numpy()
        self.memory.add(obs.cpu().numpy(), actions.cpu().numpy(), logprobs, values, rewards, terminated.cpu().numpy())

        self.next_obs = t.from_numpy(next_obs).to(device, dtype=t.float)
        self.next_terminated = t.from_numpy(next_terminated).to(device, dtype=t.float)

        self.step += self.envs.num_envs
        return infos


def calc_clipped_surrogate_objective_cts(
    dist: t.distributions.Normal,
    mb_action: Float[Tensor, "minibatch_size *action_shape"],
    mb_advantages: Float[Tensor, "minibatch_size"],
    mb_logprobs: Float[Tensor, "minibatch_size"],
    clip_coef: float,
    eps: float = 1e-8,
) -> Float[Tensor, ""]:
    """
    Changes required:
        - logprobs need to be summed over action space
    """
    assert (mb_action.shape[0],) == mb_advantages.shape == mb_logprobs.shape

    # CHANGED: logprobs need to be summed over action space
    logits_diff = dist.log_prob(mb_action).sum(-1) - mb_logprobs

    r_theta = t.exp(logits_diff)

    mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + eps)

    non_clipped = r_theta * mb_advantages
    clipped = t.clip(r_theta, 1 - clip_coef, 1 + clip_coef) * mb_advantages

    return t.minimum(non_clipped, clipped).mean()


def calc_entropy_bonus_cts(dist: t.distributions.Normal, ent_coef: float):
    """
    Changes required:
        - entropy needs to be summed over action space before taking mean
    """
    # CHANGED: sum over last dim (the action components) before taking mean
    return ent_coef * dist.entropy().sum(-1).mean()


class PPOTrainerCts(PPOTrainer):
    def __init__(self, args: PPOArgs):
        super().__init__(args)
        self.agent = PPOAgentCts(self.envs, self.actor, self.critic, self.memory)

    def compute_ppo_objective(self, minibatch: ReplayMinibatch) -> Float[Tensor, ""]:
        """
        Changes required:
            - actor returns (mu, sigma, dist), with dist used for loss functions (rather than getting dist from logits)
            - objective function calculated using new `_cts` functions defined above
            - newlogprob (for logging) needs to be summed over action space
            - mu and sigma should be logged
        """
        # CHANGED: actor returns (mu, sigma, dist), with dist used for loss functions (rather than getting dist from logits)
        mu, sigma, dist = self.agent.actor(minibatch.obs)
        values = self.agent.critic(minibatch.obs).squeeze()

        # CHANGED: objective function calculated using new `_cts` functions defined above
        clipped_surrogate_objective = calc_clipped_surrogate_objective_cts(
            dist, minibatch.actions, minibatch.advantages, minibatch.logprobs, self.args.clip_coef
        )
        value_loss = calc_value_function_loss(values, minibatch.returns, self.args.vf_coef)
        entropy_bonus = calc_entropy_bonus_cts(dist, self.args.ent_coef)
        total_objective_function = clipped_surrogate_objective - value_loss + entropy_bonus

        with t.inference_mode():
            # CHANGED: newlogprob (for logging) needs to be summed over action space
            newlogprob = dist.log_prob(minibatch.actions).sum(-1)
            logratio = newlogprob - minibatch.logprobs
            ratio = logratio.exp()
            approx_kl = (ratio - 1 - logratio).mean().item()
            clipfracs = [((ratio - 1.0).abs() > self.args.clip_coef).float().mean().item()]
        if self.args.use_wandb:
            wandb.log(
                dict(
                    total_steps=self.agent.step,
                    values=values.mean().item(),
                    lr=self.scheduler.optimizer.param_groups[0]["lr"],
                    value_loss=value_loss.item(),
                    clipped_surrogate_objective=clipped_surrogate_objective.item(),
                    entropy=entropy_bonus.item(),
                    approx_kl=approx_kl,
                    clipfrac=np.mean(clipfracs),
                    # CHANGED: mu and sigma should be logged
                    mu=mu.mean().item(),
                    sigma=sigma.mean().item(),
                ),
                step=self.agent.step,
            )

        return total_objective_function


## Training MuJoCo

Now, you should be ready to run your training loop! We recommend using the following parameters, to match the original implementation which the [37 Implementational Details](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details) post is based on (but you can experiment with different values if you like).

In [None]:
args = PPOArgs(
    env_id="Hopper-v4",
    wandb_project_name="PPOMuJoCo",
    use_wandb=True,
    mode="mujoco",
    lr=3e-4,
    ent_coef=0.0,
    num_minibatches=32,
    num_steps_per_rollout=2048,
    num_envs=1,
    video_log_freq=75,
)
trainer = PPOTrainerCts(args)
trainer.train()

You should expect the reward to increase pretty fast initially and then plateau once the agent learns the solution "kick off for a very large initial jump, and don't think about landing". Eventually the agent gets past this plateau, and learns to land successfully without immediately falling over. Once it's at the point where it can string two jumps together, your reward should start increasing much faster.

Here is a video produced from a successful run, using the parameters above:

<video width="400" height="420" controls>
<source src="https://raw.githubusercontent.com/info-arena/ARENA_img/main/misc/media-23/2305.mp4" type="video/mp4">
</video>

and here's the corresponding plot of episode lengths:

<img src="https://raw.githubusercontent.com/info-arena/ARENA_img/main/misc/wandb-mujoco-lengths.png" width="550">

Although we've used `Hopper-v4` in these examples, you might also want to try `InvertedPendulum-v4` (docs [here](https://gymnasium.farama.org/environments/mujoco/inverted_pendulum/)). It's a much easier environment to solve, and it's a good way to check that your implementation is working (after all if it worked for CartPole then it should work here - in fact your inverted pendulum agent should converge to a perfect solution almost instantly, no reward shaping required). You can check out the other MuJoCo environments [here](https://gymnasium.farama.org/environments/mujoco/).

# ☆ Bonus

## Trust Region Methods

Some versions of the PPO algorithm use a slightly different objective function. Rather than our clipped surrogate objective, they use constrained optimization (maximising the surrogate objective subject to a restriction on the [KL divergence](https://www.lesswrong.com/posts/no5jDTut5Byjqb4j5/six-and-a-half-intuitions-for-kl-divergence) between the old and new policies).

$$
\begin{array}{ll}
\underset{\theta}{\operatorname{maximize}} & \hat{\mathbb{E}}_t\left[\frac{\pi_\theta\left(a_t \mid s_t\right)}{\pi_{\theta_{\text {old}}}\left(a_t \mid s_t\right)} \hat{A}_t\right] \\
\text { subject to } & \hat{\mathbb{E}}_t\left[\mathrm{KL}\left[\pi_{\theta_{\text {old}}}\left(\cdot \mid s_t\right), \pi_\theta\left(\cdot \mid s_t\right)\right]\right] \leq \delta
\end{array}
$$

The intuition behind this is similar to the clipped surrogate objective. For our clipped objective, we made sure the model wasn't rewarded for deviating from its old policy beyond a certain point (which encourages small updates). Adding an explicit KL constraint accomplishes something similar, because it forces the model to closely adhere to the old policy. For more on KL-divergence and why it's a principled measure, see [this post](https://www.lesswrong.com/posts/no5jDTut5Byjqb4j5/six-and-a-half-intuitions-for-kl-divergence). We call these algorithms trust-region methods because they incentivise the model to stay in a **trusted region of policy space**, i.e. close to the old policy (where we can be more confident in our results).

The theory behind TRPO actually suggests the following variant - turning the strict constraint into a penalty term, which you should find easier to implement:

$$
\underset{\theta}{\operatorname{maximize}} \, \hat{\mathbb{E}}_t\left[\frac{\pi_\theta\left(a_t \mid s_t\right)}{\pi_{\theta_{\text {old}}}\left(a_t \mid s_t\right)} \hat{A}_t-\beta \mathrm{KL}\left[\pi_{\theta_{\text {old}}}\left(\cdot \mid s_t\right), \pi_\theta\left(\cdot \mid s_t\right)\right]\right]
$$

Rather than forcing the new policy to stay close to the previous policy, this adds a penalty term which incentivises this behaviour (in fact, there is a 1-1 correspondence between constrained optimization problems and the corresponding unconstrained version).

Can you implement this? Does this approach work better than the clipped surrogate objective? What values of $\beta$ work best?

Tip - you can calculate KL divergence using the PyTorch [KL Divergence function](https://pytorch.org/docs/stable/distributions.html#module-torch.distributions.kl). You could also try the approximate version, as described in [detail #12](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/#:~:text=Debug%20variables) of the "37 Implementational Details" post.
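As a starting point, here is a minimal sketch of the penalized objective for a discrete action space, using `torch.distributions.kl_divergence`. Everything here is a stand-in chosen for illustration: the logits and advantages are random, and `beta = 1.0` is an arbitrary value rather than a recommendation.

```python
import torch as t
from torch.distributions import Categorical, kl_divergence

t.manual_seed(0)
batch_size, num_actions = 8, 2

# Stand-ins for the old (rollout-time) and new (current) policies
old_logits = t.randn(batch_size, num_actions)
new_logits = old_logits + 0.1 * t.randn(batch_size, num_actions)
old_dist = Categorical(logits=old_logits)
new_dist = Categorical(logits=new_logits)

actions = old_dist.sample()
advantages = t.randn(batch_size)

# Probability ratio pi_theta(a|s) / pi_theta_old(a|s)
ratio = (new_dist.log_prob(actions) - old_dist.log_prob(actions)).exp()

# KL[pi_theta_old(.|s), pi_theta(.|s)] for each state in the batch, shape (batch_size,)
kl = kl_divergence(old_dist, new_dist)

beta = 1.0  # hypothetical penalty coefficient
objective = (ratio * advantages - beta * kl).mean()
print(objective.item(), kl.mean().item())
```

In a real implementation the old policy's logits (or distribution parameters) would need to be stored during the rollout phase, since the memory as described above only keeps logprobs of the sampled actions.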

## Long-term replay memory

Above, we discussed the problem of **catastrophic forgetting** (where the agent forgets how to recover from bad behaviour, because the memory only contains good behaviour). One way to fix this is to have a long-term replay memory, for instance:

* (simple version) You reserve e.g. 10% of your buffer for experiences generated at the start of training.
* (complex version) You design a custom scheduled method for removing experiences from memory, so that you always have a mix of old and new experiences.

Can you implement one of these, and does it fix the catastrophic forgetting problem (without needing to use reward shaping)?
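To make the simple version concrete, here is a minimal sketch of a buffer that freezes the first experiences it sees and cycles the rest. The class name, the `frozen_fraction` parameter and the generic "experience" objects are all hypothetical; you would need to adapt this to the tensor-based memory used in your own agent. Bear in mind that the frozen experiences were generated by an old policy, so they are off-policy with respect to the current one.

```python
import random
from collections import deque


class LongTermReplayBuffer:
    """Hypothetical buffer: a frozen slice of early experiences plus a rolling recent slice."""

    def __init__(self, capacity, frozen_fraction=0.1, seed=0):
        # Number of slots reserved for experiences from the start of training
        self.frozen_capacity = round(capacity * frozen_fraction)
        self.frozen = []
        # The remaining slots hold the most recent experiences (oldest dropped first)
        self.recent = deque(maxlen=capacity - self.frozen_capacity)
        self.rng = random.Random(seed)

    def add(self, experience):
        if len(self.frozen) < self.frozen_capacity:
            self.frozen.append(experience)  # early experiences are never evicted
        else:
            self.recent.append(experience)

    def sample(self, batch_size):
        pool = self.frozen + list(self.recent)
        return self.rng.sample(pool, min(batch_size, len(pool)))

    def __len__(self):
        return len(self.frozen) + len(self.recent)
```

For example, with `capacity=10` and the default fraction, the first experience added is kept permanently and the other nine slots always hold the nine most recent experiences.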

## Vectorized Advantage Calculation

Try optimizing away the for-loop in your advantage calculation. It's tricky (and quite messy), so an easier version of this is: find a vectorized calculation and try to explain what it does.

<details>
<summary>Hint (for your own implementation)</summary>

*(Assume `num_envs=1` for simplicity)*

Construct a 2D boolean array from `dones`, where the `(i, j)`-th element of the array tells you whether the expression for the `i`-th advantage function should include rewards / values at timestep `j`. You can do this via careful use of `torch.cumsum`, `torch.triu`, and some rearranging.
</details>
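
<details>
<summary>Sketch of one possible vectorized calculation (spoiler)</summary>

The hint above can be turned into code roughly as follows. This is a sketch under stated assumptions rather than the reference solution: it assumes `num_envs=1`, 1D float tensors of length `T`, and that `dones[t] = 1` means the episode ended on the step taken at timestep `t` (your own buffer may use a convention shifted by one step). The function and argument names are illustrative.

```python
import torch


def episode_mask(dones):
    """mask[i, j] = 1 iff j >= i and timesteps i and j belong to the same episode."""
    dones = dones.long()
    # Number of episode ends strictly before each timestep, i.e. an episode index
    episode_id = torch.cumsum(dones, dim=0) - dones
    same_episode = (episode_id[:, None] == episode_id[None, :]).float()
    # Keep only the upper triangle: advantage i only looks at timesteps j >= i
    return torch.triu(same_episode)


def vectorized_gae(rewards, values, next_value, dones, gamma, gae_lambda):
    """GAE without a Python loop (assumes num_envs=1)."""
    T = rewards.shape[0]
    dones = dones.float()
    next_values = torch.cat([values[1:], next_value.reshape(1)])
    # One-step TD errors, with bootstrapping switched off at episode ends
    deltas = rewards + gamma * next_values * (1.0 - dones) - values

    # discount[i, j] = (gamma * lambda) ** (j - i) for j >= i
    steps = torch.arange(T)
    exponent = (steps[None, :] - steps[:, None]).clamp(min=0).float()
    discount = (gamma * gae_lambda) ** exponent

    # advantage[i] = sum over j of mask[i, j] * discount[i, j] * delta[j]
    return (episode_mask(dones) * discount) @ deltas
```

Note that this builds a `T x T` matrix, so it trades memory for the removal of the loop.

</details>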

<!-- There are solutions available in `solutions.py` (commented out). -->

## Other Discrete Environments

Three environments (supported by gym) which you might like to try are:

* [`Acrobot-v1`](https://gymnasium.farama.org/environments/classic_control/acrobot/) - this is one of the [Classic Control environments](https://gymnasium.farama.org/environments/classic_control/), and it's a bit harder to learn than CartPole.
* [`MountainCar-v0`](https://gymnasium.farama.org/environments/classic_control/mountain_car/) - this is one of the [Classic Control environments](https://gymnasium.farama.org/environments/classic_control/), and it's much harder to learn than CartPole. This is primarily because of **sparse rewards** (it's really hard to get to the top of the hill), so you'll definitely need reward shaping to get through it!
* [`LunarLander-v2`](https://gymnasium.farama.org/environments/box2d/lunar_lander/) - this is part of the [Box2d](https://gymnasium.farama.org/environments/box2d/) environments. It's a bit harder still, because of the added environmental complexity (physics like gravity and friction, and constraints like fuel conservation). The reward is denser (with the agent receiving rewards for moving towards the landing pad and penalties for moving away or crashing), but the increased complexity makes it overall a harder problem to solve. You might have to perform hyperparameter sweeps to find a good configuration (you can go back and look at the syntax for hyperparameter sweeps [here](https://arena-ch0-fundamentals.streamlit.app/[0.4]_Optimization)). Also, [this page](https://pylessons.com/LunarLander-v2-PPO) might be a useful reference (although the details of their implementation differ from the one we used today). You can look at the hyperparameters they used.

## Continuous Action Spaces & Reward Shaping

The `MountainCar-v0` environment has discrete actions, but there's also a version `MountainCarContinuous-v0` with continuous action spaces. Implementing this will require a combination of the continuous action spaces you dealt with during the MuJoCo section, and the reward shaping you used during the CartPole exercises.
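One option for the reward shaping part is potential-based shaping, where you add `gamma * potential(next_obs) - potential(obs)` to the environment's reward. The sketch below is only an illustration: the choice of potential (hill height plus a velocity term), the `velocity_scale` coefficient and the function names are all assumptions that you should expect to tune. It relies on the observation being `(position, velocity)` and on the hill having a `sin(3 * position)` profile.

```python
import math


def potential(obs, velocity_scale=100.0):
    """Hypothetical potential: higher up the hill and faster is better."""
    position, velocity = obs
    # Height of the hill at this position (sin(3x) profile, rescaled to be positive)
    height = math.sin(3 * position) * 0.45 + 0.55
    # velocity_scale is an illustrative coefficient, not a tuned value
    return height + velocity_scale * velocity**2


def shaped_reward(reward, obs, next_obs, terminated, gamma=0.99):
    """Environment reward plus a potential-based shaping term."""
    # The potential of a terminal state is taken to be zero
    next_potential = 0.0 if terminated else potential(next_obs)
    return reward + gamma * next_potential - potential(obs)
```

You would call `shaped_reward` on each transition before storing it in memory, for instance from inside an environment wrapper. Because the function takes the environment's own reward as input, the same sketch can be tried on both the discrete and the continuous version.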

## Choose & build your own environment (e.g. Wordle)

You can also try choosing your own task, framing it as an RL problem, and adapting your PPO algorithm to solve it. For example, training an agent to play Wordle (or a related game like Semantle) might be a suitably difficult task. [This post](https://wandb.ai/andrewkho/wordle-solver/reports/Solving-Wordle-with-Reinforcement-Learning--VmlldzoxNTUzOTc4) gives a high-level account of training an agent to play Wordle. They use DQN and don't go too deep into the technical details, and it's possible that PPO would work better for this task.

## Minigrid envs / Procgen

There are many more exciting environments to play in, but generally they're going to require more compute and more optimization than we have time for today. If you want to try them out, some we recommend are:

- [Minimalistic Gridworld Environments](https://github.com/Farama-Foundation/gym-minigrid) - a fast gridworld environment for experiments with sparse rewards and natural language instruction.
- [microRTS](https://github.com/santiontanon/microrts) - a small real-time strategy game suitable for experimentation.
- [Megastep](https://andyljones.com/megastep/) - RL environment that runs fully on the GPU (fast!)
- [Procgen](https://github.com/openai/procgen) - A family of 16 procedurally generated gym environments to measure the ability for an agent to generalize. Optimized to run quickly on the CPU.
    - For this one, you might want to read [Jacob Hilton's online DL tutorial](https://github.com/jacobhilton/deep_learning_curriculum/blob/master/6-Reinforcement-Learning.md) (the RL chapter suggests implementing PPO on Procgen), and [Connor Kissane's solutions](https://github.com/ckkissane/deep_learning_curriculum/blob/master/solutions/6_Reinforcement_Learning.ipynb).

## Multi-Agent PPO

Multi-Agent PPO (MAPPO) is an extension of the standard PPO algorithm which trains multiple agents at once. It was first described in the paper [The Surprising Effectiveness of PPO in Cooperative Multi-Agent Games](https://arxiv.org/abs/2103.01955). Can you implement MAPPO?