In [1]:
#@markdown ### **Installing pip packages**
#@markdown - Diffusion Model: [PyTorch](https://pytorch.org) & [HuggingFace diffusers](https://huggingface.co/docs/diffusers/index)
#@markdown - Dataset Loading: [Zarr](https://zarr.readthedocs.io/en/stable/) & numcodecs
#@markdown - Push-T Env: gym, pygame, pymunk & shapely
!python --version
!pip3 install torch torchvision diffusers \
scikit-image scikit-video zarr numcodecs \
pygame pymunk gym shapely opencv-python

Python 3.10.12
Collecting torch
  Downloading torch-2.6.0-cp310-cp310-manylinux1_x86_64.whl (766.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m766.7/766.7 MB[0m [31m10.7 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting torchvision
  Downloading torchvision-0.21.0-cp310-cp310-manylinux1_x86_64.whl (7.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.2/7.2 MB[0m [31m65.1 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting diffusers
  Downloading diffusers-0.33.1-py3-none-any.whl (3.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.6/3.6 MB[0m [31m90.9 MB/s[0m eta [36m0:00:00[0m:00:01[0m
[?25hCollecting scikit-image
  Downloading scikit_image-0.25.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (14.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.8/14.8 MB[0m [31m84.5 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting scikit-video
 

In [1]:
#@markdown ### **Imports**
# diffusion policy import
from typing import Tuple, Sequence, Dict, Union, Optional
import numpy as np
import math
import torch
import torch.nn as nn
import collections
import zarr
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.training_utils import EMAModel
from diffusers.optimization import get_scheduler
from tqdm.auto import tqdm

# env import
import gym
from gym import spaces
import pygame
import pymunk
import pymunk.pygame_util
from pymunk.space_debug_draw_options import SpaceDebugColor
from pymunk.vec2d import Vec2d
import shapely.geometry as sg
import cv2
import skimage.transform as st
from skvideo.io import vwrite
from IPython.display import Video
import gdown
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#@markdown ### **Environment**
#@markdown Defines a PyMunk-based Push-T environment `PushTEnv`.
#@markdown
#@markdown **Goal**: push the gray T-block into the green area.
#@markdown
#@markdown Adapted from [Implicit Behavior Cloning](https://implicitbc.github.io/)


positive_y_is_up: bool = False
"""Make increasing values of y point upwards.

When True::

    y
    ^
    |      . (3, 3)
    |
    |   . (2, 2)
    |
    +------ > x

When False::

    +------ > x
    |
    |   . (2, 2)
    |
    |      . (3, 3)
    v
    y

"""

def to_pygame(p: Tuple[float, float], surface: pygame.Surface) -> Tuple[int, int]:
    """Convenience method to convert pymunk coordinates to pygame surface
    local coordinates.

    Note that in case positive_y_is_up is False, this function wont actually do
    anything except converting the point to integers.
    """
    if positive_y_is_up:
        return round(p[0]), surface.get_height() - round(p[1])
    else:
        return round(p[0]), round(p[1])


def light_color(color: SpaceDebugColor):
    color = np.minimum(1.2 * np.float32([color.r, color.g, color.b, color.a]), np.float32([255]))
    color = SpaceDebugColor(r=color[0], g=color[1], b=color[2], a=color[3])
    return color

class DrawOptions(pymunk.SpaceDebugDrawOptions):
    def __init__(self, surface: pygame.Surface) -> None:
        """Draw a pymunk.Space on a pygame.Surface object.

        Typical usage::

        >>> import pymunk
        >>> surface = pygame.Surface((10,10))
        >>> space = pymunk.Space()
        >>> options = pymunk.pygame_util.DrawOptions(surface)
        >>> space.debug_draw(options)

        You can control the color of a shape by setting shape.color to the color
        you want it drawn in::

        >>> c = pymunk.Circle(None, 10)
        >>> c.color = pygame.Color("pink")

        See pygame_util.demo.py for a full example

        Since pygame uses a coordiante system where y points down (in contrast
        to many other cases), you either have to make the physics simulation
        with Pymunk also behave in that way, or flip everything when you draw.

        The easiest is probably to just make the simulation behave the same
        way as Pygame does. In that way all coordinates used are in the same
        orientation and easy to reason about::

        >>> space = pymunk.Space()
        >>> space.gravity = (0, -1000)
        >>> body = pymunk.Body()
        >>> body.position = (0, 0) # will be positioned in the top left corner
        >>> space.debug_draw(options)

        To flip the drawing its possible to set the module property
        :py:data:`positive_y_is_up` to True. Then the pygame drawing will flip
        the simulation upside down before drawing::

        >>> positive_y_is_up = True
        >>> body = pymunk.Body()
        >>> body.position = (0, 0)
        >>> # Body will be position in bottom left corner

        :Parameters:
                surface : pygame.Surface
                    Surface that the objects will be drawn on
        """
        self.surface = surface
        super(DrawOptions, self).__init__()

    def draw_circle(
        self,
        pos: Vec2d,
        angle: float,
        radius: float,
        outline_color: SpaceDebugColor,
        fill_color: SpaceDebugColor,
    ) -> None:
        p = to_pygame(pos, self.surface)

        pygame.draw.circle(self.surface, fill_color.as_int(), p, round(radius), 0)
        pygame.draw.circle(self.surface, light_color(fill_color).as_int(), p, round(radius-4), 0)

        circle_edge = pos + Vec2d(radius, 0).rotated(angle)
        p2 = to_pygame(circle_edge, self.surface)
        line_r = 2 if radius > 20 else 1
        # pygame.draw.lines(self.surface, outline_color.as_int(), False, [p, p2], line_r)

    def draw_segment(self, a: Vec2d, b: Vec2d, color: SpaceDebugColor) -> None:
        p1 = to_pygame(a, self.surface)
        p2 = to_pygame(b, self.surface)

        pygame.draw.aalines(self.surface, color.as_int(), False, [p1, p2])

    def draw_fat_segment(
        self,
        a: Tuple[float, float],
        b: Tuple[float, float],
        radius: float,
        outline_color: SpaceDebugColor,
        fill_color: SpaceDebugColor,
    ) -> None:
        p1 = to_pygame(a, self.surface)
        p2 = to_pygame(b, self.surface)

        r = round(max(1, radius * 2))
        pygame.draw.lines(self.surface, fill_color.as_int(), False, [p1, p2], r)
        if r > 2:
            orthog = [abs(p2[1] - p1[1]), abs(p2[0] - p1[0])]
            if orthog[0] == 0 and orthog[1] == 0:
                return
            scale = radius / (orthog[0] * orthog[0] + orthog[1] * orthog[1]) ** 0.5
            orthog[0] = round(orthog[0] * scale)
            orthog[1] = round(orthog[1] * scale)
            points = [
                (p1[0] - orthog[0], p1[1] - orthog[1]),
                (p1[0] + orthog[0], p1[1] + orthog[1]),
                (p2[0] + orthog[0], p2[1] + orthog[1]),
                (p2[0] - orthog[0], p2[1] - orthog[1]),
            ]
            pygame.draw.polygon(self.surface, fill_color.as_int(), points)
            pygame.draw.circle(
                self.surface,
                fill_color.as_int(),
                (round(p1[0]), round(p1[1])),
                round(radius),
            )
            pygame.draw.circle(
                self.surface,
                fill_color.as_int(),
                (round(p2[0]), round(p2[1])),
                round(radius),
            )

    def draw_polygon(
        self,
        verts: Sequence[Tuple[float, float]],
        radius: float,
        outline_color: SpaceDebugColor,
        fill_color: SpaceDebugColor,
    ) -> None:
        ps = [to_pygame(v, self.surface) for v in verts]
        ps += [ps[0]]

        radius = 2
        pygame.draw.polygon(self.surface, light_color(fill_color).as_int(), ps)

        if radius > 0:
            for i in range(len(verts)):
                a = verts[i]
                b = verts[(i + 1) % len(verts)]
                self.draw_fat_segment(a, b, radius, fill_color, fill_color)

    def draw_dot(
        self, size: float, pos: Tuple[float, float], color: SpaceDebugColor
    ) -> None:
        p = to_pygame(pos, self.surface)
        pygame.draw.circle(self.surface, color.as_int(), p, round(size), 0)


def pymunk_to_shapely(body, shapes):
    geoms = list()
    for shape in shapes:
        if isinstance(shape, pymunk.shapes.Poly):
            verts = [body.local_to_world(v) for v in shape.get_vertices()]
            verts += [verts[0]]
            geoms.append(sg.Polygon(verts))
        else:
            raise RuntimeError(f'Unsupported shape type {type(shape)}')
    geom = sg.MultiPolygon(geoms)
    return geom

# env
class PushTEnv(gym.Env):
    metadata = {"render.modes": ["human", "rgb_array"], "video.frames_per_second": 10}
    reward_range = (0., 1.)

    def __init__(self,
            legacy=False,
            block_cog=None, damping=None,
            render_action=True,
            render_size=96,
            reset_to_state=None
        ):
        self._seed = None
        self.seed()
        self.window_size = ws = 512  # The size of the PyGame window
        self.render_size = render_size
        self.sim_hz = 100
        # Local controller params.
        self.k_p, self.k_v = 100, 20    # PD control.z
        self.control_hz = self.metadata['video.frames_per_second']
        # legcay set_state for data compatiblity
        self.legacy = legacy

        # agent_pos, block_pos, block_angle
        self.observation_space = spaces.Box(
            low=np.array([0,0,0,0,0], dtype=np.float64),
            high=np.array([ws,ws,ws,ws,np.pi*2], dtype=np.float64),
            shape=(5,),
            dtype=np.float64
        )

        # positional goal for agent
        self.action_space = spaces.Box(
            low=np.array([0,0], dtype=np.float64),
            high=np.array([ws,ws], dtype=np.float64),
            shape=(2,),
            dtype=np.float64
        )

        self.block_cog = block_cog
        self.damping = damping
        self.render_action = render_action

        """
        If human-rendering is used, `self.window` will be a reference
        to the window that we draw to. `self.clock` will be a clock that is used
        to ensure that the environment is rendered at the correct framerate in
        human-mode. They will remain `None` until human-mode is used for the
        first time.
        """
        self.window = None
        self.clock = None
        self.screen = None

        self.space = None
        self.teleop = None
        self.render_buffer = None
        self.latest_action = None
        self.reset_to_state = reset_to_state

    def reset(self):
        seed = self._seed
        self._setup()
        if self.block_cog is not None:
            self.block.center_of_gravity = self.block_cog
        if self.damping is not None:
            self.space.damping = self.damping

        # use legacy RandomState for compatiblity
        state = self.reset_to_state
        if state is None:
            rs = np.random.RandomState(seed=seed)
            state = np.array([
                rs.randint(50, 450), rs.randint(50, 450),
                rs.randint(100, 400), rs.randint(100, 400),
                rs.randn() * 2 * np.pi - np.pi
                ])
        self._set_state(state)

        obs = self._get_obs()
        info = self._get_info()
        return obs, info

    def step(self, action):
        dt = 1.0 / self.sim_hz
        self.n_contact_points = 0
        n_steps = self.sim_hz // self.control_hz
        if action is not None:
            self.latest_action = action
            for i in range(n_steps):
                # Step PD control.
                # self.agent.velocity = self.k_p * (act - self.agent.position)    # P control works too.
                acceleration = self.k_p * (action - self.agent.position) + self.k_v * (Vec2d(0, 0) - self.agent.velocity)
                self.agent.velocity += acceleration * dt

                # Step physics.
                self.space.step(dt)

        # compute reward
        goal_body = self._get_goal_pose_body(self.goal_pose)
        goal_geom = pymunk_to_shapely(goal_body, self.block.shapes)
        block_geom = pymunk_to_shapely(self.block, self.block.shapes)

        intersection_area = goal_geom.intersection(block_geom).area
        goal_area = goal_geom.area
        coverage = intersection_area / goal_area
        reward = np.clip(coverage / self.success_threshold, 0, 1)
        done = coverage > self.success_threshold
        terminated = done
        truncated = done

        observation = self._get_obs()
        info = self._get_info()

        return observation, reward, terminated, truncated, info

    def render(self, mode):
        return self._render_frame(mode)

    def teleop_agent(self):
        TeleopAgent = collections.namedtuple('TeleopAgent', ['act'])
        def act(obs):
            act = None
            mouse_position = pymunk.pygame_util.from_pygame(Vec2d(*pygame.mouse.get_pos()), self.screen)
            if self.teleop or (mouse_position - self.agent.position).length < 30:
                self.teleop = True
                act = mouse_position
            return act
        return TeleopAgent(act)

    def _get_obs(self):
        obs = np.array(
            tuple(self.agent.position) \
            + tuple(self.block.position) \
            + (self.block.angle % (2 * np.pi),))
        return obs

    def _get_goal_pose_body(self, pose):
        mass = 1
        inertia = pymunk.moment_for_box(mass, (50, 100))
        body = pymunk.Body(mass, inertia)
        # preserving the legacy assignment order for compatibility
        # the order here dosn't matter somehow, maybe because CoM is aligned with body origin
        body.position = pose[:2].tolist()
        body.angle = pose[2]
        return body

    def _get_info(self):
        n_steps = self.sim_hz // self.control_hz
        n_contact_points_per_step = int(np.ceil(self.n_contact_points / n_steps))
        info = {
            'pos_agent': np.array(self.agent.position),
            'vel_agent': np.array(self.agent.velocity),
            'block_pose': np.array(list(self.block.position) + [self.block.angle]),
            'goal_pose': self.goal_pose,
            'n_contacts': n_contact_points_per_step}
        return info

    def _render_frame(self, mode):

        if self.window is None and mode == "human":
            pygame.init()
            pygame.display.init()
            self.window = pygame.display.set_mode((self.window_size, self.window_size))
        if self.clock is None and mode == "human":
            self.clock = pygame.time.Clock()

        canvas = pygame.Surface((self.window_size, self.window_size))
        canvas.fill((255, 255, 255))
        self.screen = canvas

        draw_options = DrawOptions(canvas)

        # Draw goal pose.
        goal_body = self._get_goal_pose_body(self.goal_pose)
        for shape in self.block.shapes:
            goal_points = [pymunk.pygame_util.to_pygame(goal_body.local_to_world(v), draw_options.surface) for v in shape.get_vertices()]
            goal_points += [goal_points[0]]
            pygame.draw.polygon(canvas, self.goal_color, goal_points)

        # Draw agent and block.
        self.space.debug_draw(draw_options)

        if mode == "human":
            # The following line copies our drawings from `canvas` to the visible window
            self.window.blit(canvas, canvas.get_rect())
            pygame.event.pump()
            pygame.display.update()

            # the clock is aleady ticked during in step for "human"


        img = np.transpose(
                np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2)
            )
        img = cv2.resize(img, (self.render_size, self.render_size))
        if self.render_action:
            if self.render_action and (self.latest_action is not None):
                action = np.array(self.latest_action)
                coord = (action / 512 * 96).astype(np.int32)
                marker_size = int(8/96*self.render_size)
                thickness = int(1/96*self.render_size)
                cv2.drawMarker(img, coord,
                    color=(255,0,0), markerType=cv2.MARKER_CROSS,
                    markerSize=marker_size, thickness=thickness)
        return img


    def close(self):
        if self.window is not None:
            pygame.display.quit()
            pygame.quit()

    def seed(self, seed=None):
        if seed is None:
            seed = np.random.randint(0,25536)
        self._seed = seed
        self.np_random = np.random.default_rng(seed)

    def _handle_collision(self, arbiter, space, data):
        self.n_contact_points += len(arbiter.contact_point_set.points)

    def _set_state(self, state):
        if isinstance(state, np.ndarray):
            state = state.tolist()
        pos_agent = state[:2]
        pos_block = state[2:4]
        rot_block = state[4]
        self.agent.position = pos_agent
        # setting angle rotates with respect to center of mass
        # therefore will modify the geometric position
        # if not the same as CoM
        # therefore should be modified first.
        if self.legacy:
            # for compatiblity with legacy data
            self.block.position = pos_block
            self.block.angle = rot_block
        else:
            self.block.angle = rot_block
            self.block.position = pos_block

        # Run physics to take effect
        self.space.step(1.0 / self.sim_hz)

    def _set_state_local(self, state_local):
        agent_pos_local = state_local[:2]
        block_pose_local = state_local[2:]
        tf_img_obj = st.AffineTransform(
            translation=self.goal_pose[:2],
            rotation=self.goal_pose[2])
        tf_obj_new = st.AffineTransform(
            translation=block_pose_local[:2],
            rotation=block_pose_local[2]
        )
        tf_img_new = st.AffineTransform(
            matrix=tf_img_obj.params @ tf_obj_new.params
        )
        agent_pos_new = tf_img_new(agent_pos_local)
        new_state = np.array(
            list(agent_pos_new[0]) + list(tf_img_new.translation) \
                + [tf_img_new.rotation])
        self._set_state(new_state)
        return new_state

    def _setup(self):
        self.space = pymunk.Space()
        self.space.gravity = 0, 0
        self.space.damping = 0
        self.teleop = False
        self.render_buffer = list()

        # Add walls.
        walls = [
            self._add_segment((5, 506), (5, 5), 2),
            self._add_segment((5, 5), (506, 5), 2),
            self._add_segment((506, 5), (506, 506), 2),
            self._add_segment((5, 506), (506, 506), 2)
        ]
        self.space.add(*walls)

        # Add agent, block, and goal zone.
        self.agent = self.add_circle((256, 400), 15)
        self.block = self.add_tee((256, 300), 0)
        self.goal_color = pygame.Color('LightGreen')
        self.goal_pose = np.array([256,256,np.pi/4])  # x, y, theta (in radians)

        # Add collision handeling
        self.collision_handeler = self.space.add_collision_handler(0, 0)
        self.collision_handeler.post_solve = self._handle_collision
        self.n_contact_points = 0

        self.max_score = 50 * 100
        self.success_threshold = 0.95    # 95% coverage.

    def _add_segment(self, a, b, radius):
        shape = pymunk.Segment(self.space.static_body, a, b, radius)
        shape.color = pygame.Color('LightGray')    # https://htmlcolorcodes.com/color-names
        return shape

    def add_circle(self, position, radius):
        body = pymunk.Body(body_type=pymunk.Body.KINEMATIC)
        body.position = position
        body.friction = 1
        shape = pymunk.Circle(body, radius)
        shape.color = pygame.Color('RoyalBlue')
        self.space.add(body, shape)
        return body

    def add_box(self, position, height, width):
        mass = 1
        inertia = pymunk.moment_for_box(mass, (height, width))
        body = pymunk.Body(mass, inertia)
        body.position = position
        shape = pymunk.Poly.create_box(body, (height, width))
        shape.color = pygame.Color('LightSlateGray')
        self.space.add(body, shape)
        return body

    def add_tee(self, position, angle, scale=30, color='LightSlateGray', mask=pymunk.ShapeFilter.ALL_MASKS()):
        mass = 1
        length = 4
        vertices1 = [(-length*scale/2, scale),
                                 ( length*scale/2, scale),
                                 ( length*scale/2, 0),
                                 (-length*scale/2, 0)]
        inertia1 = pymunk.moment_for_poly(mass, vertices=vertices1)
        vertices2 = [(-scale/2, scale),
                                 (-scale/2, length*scale),
                                 ( scale/2, length*scale),
                                 ( scale/2, scale)]
        inertia2 = pymunk.moment_for_poly(mass, vertices=vertices1)
        body = pymunk.Body(mass, inertia1 + inertia2)
        shape1 = pymunk.Poly(body, vertices1)
        shape2 = pymunk.Poly(body, vertices2)
        shape1.color = pygame.Color(color)
        shape2.color = pygame.Color(color)
        shape1.filter = pymunk.ShapeFilter(mask=mask)
        shape2.filter = pymunk.ShapeFilter(mask=mask)
        body.center_of_gravity = (shape1.center_of_gravity + shape2.center_of_gravity) / 2
        body.position = position
        body.angle = angle
        body.friction = 1
        self.space.add(body, shape1, shape2)
        return body


In [3]:
# from huggingface_hub.utils import IGNORE_GIT_FOLDER_PATTERNS
#@markdown ### **Env Demo**
#@markdown Standard Gym Env (0.21.0 API)

# 0. create env object
env = PushTEnv()

# 1. seed env for initial state.
# Seed 0-200 are used for the demonstration dataset.
env.seed(1000)

# 2. must reset before use
obs, IGNORE_GIT_FOLDER_PATTERNS = env.reset()

# 3. 2D positional action space [0,512]
action = env.action_space.sample()

# 4. Standard gym step method
obs, reward, terminated, truncated, info = env.step(action)

# prints and explains each dimension of the observation and action vectors
with np.printoptions(precision=4, suppress=True, threshold=5):
    print("Obs: ", repr(obs))
    print("Obs:        [agent_x,  agent_y,  block_x,  block_y,    block_angle]")
    print("Action: ", repr(action))
    print("Action:   [target_agent_x, target_agent_y]")

Obs:  array([137.0963, 130.7458, 292.    , 351.    ,   2.9196])
Obs:        [agent_x,  agent_y,  block_x,  block_y,    block_angle]
Action:  array([137.3257, 153.9614])
Action:   [target_agent_x, target_agent_y]


In [4]:
#@markdown ### **Dataset**
#@markdown
#@markdown Defines `PushTStateDataset` and helper functions
#@markdown
#@markdown The dataset class
#@markdown - Load data (obs, action) from a zarr storage
#@markdown - Normalizes each dimension of obs and action to [-1,1]
#@markdown - Returns
#@markdown  - All possible segments with length `pred_horizon`
#@markdown  - Pads the beginning and the end of each episode with repetition
#@markdown  - key `obs`: shape (obs_horizon, obs_dim)
#@markdown  - key `action`: shape (pred_horizon, action_dim)

def create_sample_indices(
        episode_ends:np.ndarray, sequence_length:int,
        pad_before: int=0, pad_after: int=0):
    indices = list()
    for i in range(len(episode_ends)):
        start_idx = 0
        if i > 0:
            start_idx = episode_ends[i-1]
        end_idx = episode_ends[i]
        episode_length = end_idx - start_idx

        min_start = -pad_before
        max_start = episode_length - sequence_length + pad_after

        # range stops one idx before end
        for idx in range(min_start, max_start+1):
            buffer_start_idx = max(idx, 0) + start_idx
            buffer_end_idx = min(idx+sequence_length, episode_length) + start_idx
            start_offset = buffer_start_idx - (idx+start_idx)
            end_offset = (idx+sequence_length+start_idx) - buffer_end_idx
            sample_start_idx = 0 + start_offset
            sample_end_idx = sequence_length - end_offset
            indices.append([
                buffer_start_idx, buffer_end_idx,
                sample_start_idx, sample_end_idx])
    indices = np.array(indices)
    return indices


def sample_sequence(train_data, sequence_length,
                    buffer_start_idx, buffer_end_idx,
                    sample_start_idx, sample_end_idx):
    result = dict()
    for key, input_arr in train_data.items():
        sample = input_arr[buffer_start_idx:buffer_end_idx]
        data = sample
        if (sample_start_idx > 0) or (sample_end_idx < sequence_length):
            data = np.zeros(
                shape=(sequence_length,) + input_arr.shape[1:],
                dtype=input_arr.dtype)
            if sample_start_idx > 0:
                data[:sample_start_idx] = sample[0]
            if sample_end_idx < sequence_length:
                data[sample_end_idx:] = sample[-1]
            data[sample_start_idx:sample_end_idx] = sample
        result[key] = data
    return result

# normalize data
def get_data_stats(data):
    data = data.reshape(-1,data.shape[-1])
    stats = {
        'min': np.min(data, axis=0),
        'max': np.max(data, axis=0),
        'std': np.std(data, axis=0)
    }
    return stats

def normalize_data(data, stats):
    # nomalize to [0,1]
    ndata = (data - stats['min']) / (stats['max'] - stats['min'])
    # normalize to [-1, 1]
    ndata = ndata * 2 - 1
    return ndata

def unnormalize_data(ndata, stats):
    ndata = (ndata + 1) / 2
    data = ndata * (stats['max'] - stats['min']) + stats['min']
    return data

# dataset
class PushTStateDataset(torch.utils.data.Dataset):
    def __init__(self, dataset_path,
                 pred_horizon, obs_horizon, action_horizon):

        # read from zarr dataset
        dataset_root = zarr.open(dataset_path, 'r')
        # All demonstration episodes are concatinated in the first dimension N
        train_data = {
            # (N, action_dim)
            'action': dataset_root['data']['action'][:],
            # (N, obs_dim)
            'obs': dataset_root['data']['state'][:]
        }
        # Marks one-past the last index for each episode
        episode_ends = dataset_root['meta']['episode_ends'][:]

        # compute start and end of each state-action sequence
        # also handles padding
        indices = create_sample_indices(
            episode_ends=episode_ends,
            sequence_length=pred_horizon,
            # add padding such that each timestep in the dataset are seen
            pad_before=obs_horizon-1,
            pad_after=action_horizon-1)

        # compute statistics and normalized data to [-1,1]
        stats = dict()
        normalized_train_data = dict()
        for key, data in train_data.items():
            stats[key] = get_data_stats(data)
            normalized_train_data[key] = normalize_data(data, stats[key])
            stats[key+'_normalized'] = get_data_stats(normalized_train_data[key])

        self.indices = indices
        self.stats = stats
        self.normalized_train_data = normalized_train_data
        self.pred_horizon = pred_horizon
        self.action_horizon = action_horizon
        self.obs_horizon = obs_horizon

    def __len__(self):
        # all possible segments of the dataset
        return len(self.indices)

    def __getitem__(self, idx):
        # get the start/end indices for this datapoint
        buffer_start_idx, buffer_end_idx, \
            sample_start_idx, sample_end_idx = self.indices[idx]

        # get nomralized data using these indices
        nsample = sample_sequence(
            train_data=self.normalized_train_data,
            sequence_length=self.pred_horizon,
            buffer_start_idx=buffer_start_idx,
            buffer_end_idx=buffer_end_idx,
            sample_start_idx=sample_start_idx,
            sample_end_idx=sample_end_idx
        )

        # discard unused observations
        nsample['obs'] = nsample['obs'][:self.obs_horizon,:]
        return nsample


In [7]:
#@markdown ### **Dataset Demo**

# download demonstration data from Google Drive
dataset_path = "pusht_cchi_v7_replay.zarr.zip"
if not os.path.isfile(dataset_path):
    id = "1KY1InLurpMvJDRb14L9NlXT_fEsCvVUq&confirm=t"
    gdown.download(id=id, output=dataset_path, quiet=False)

# parameters
pred_horizon = 16
obs_horizon = 2
action_horizon = 8
#|o|o|                             observations: 2
#| |a|a|a|a|a|a|a|a|               actions executed: 8
#|p|p|p|p|p|p|p|p|p|p|p|p|p|p|p|p| actions predicted: 16

# create dataset from file
dataset = PushTStateDataset(
    dataset_path=dataset_path,
    pred_horizon=pred_horizon,
    obs_horizon=obs_horizon,
    action_horizon=action_horizon
)
# save training data statistics (min, max) for each dim
stats = dataset.stats

# create dataloader
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    num_workers=4,
    shuffle=True,
    # accelerate cpu-gpu transfer
    pin_memory=True,
    # don't kill worker process afte each epoch
    persistent_workers=True
)

# visualize data in batch
batch = next(iter(dataloader))
print("batch['obs'].shape:", batch['obs'].shape)
print("batch['action'].shape", batch['action'].shape)

batch['obs'].shape: torch.Size([256, 2, 5])
batch['action'].shape torch.Size([256, 16, 2])


In [24]:
# dataset size
print(len(dataset))
stats['action_normalized']['std'].mean()

24208


np.float32(0.40121755)

In [25]:
#@markdown ### **Network**
#@markdown
#@markdown Defines a 1D UNet architecture `ConditionalUnet1D`
#@markdown as the noies prediction network
#@markdown
#@markdown Components
#@markdown - `SinusoidalPosEmb` Positional encoding for the diffusion iteration k
#@markdown - `Downsample1d` Strided convolution to reduce temporal resolution
#@markdown - `Upsample1d` Transposed convolution to increase temporal resolution
#@markdown - `Conv1dBlock` Conv1d --> GroupNorm --> Mish
#@markdown - `ConditionalResidualBlock1D` Takes two inputs `x` and `cond`. \
#@markdown `x` is passed through 2 `Conv1dBlock` stacked together with residual connection.
#@markdown `cond` is applied to `x` with [FiLM](https://arxiv.org/abs/1709.07871) conditioning.

# this is the unet for the robotics problem, modify this to use inductive moment matching
# IMM

'''
η_t = t/α_t.
OT-FM schedule: α_t = 1-t, σ_t = t
x_t = α_t·x + σ_t·ε # ε ~ N(0, I)
cout(t) = −t·σ_d

'''

import math
import torch
import torch.nn as nn
from typing import Union
from tqdm import tqdm
class IMMloss(nn.Module):
    """
    IMM loss function using the Laplace kernel.
    """

    def __init__(self, obs_horizon, pred_horizon, num_particles):
        super(IMMloss, self).__init__()
        self.obs_horizon = obs_horizon
        self.pred_horizon = pred_horizon
        self.num_particles = num_particles
        
    def laplace_kernel(self, x, y, w_scale, eps=0.006, dim_normalize=True):
        """
        Laplace kernel: exp(w_scale * max(||x-y||_2, eps)/D)
        
        Args:
            x, y: input tensors
            w_scale: scaling factor (time-dependent)
            eps: small constant to avoid undefined gradients
            dim_normalize: whether to normalize by dimensionality D
        """
        D = x.shape[-1] if dim_normalize else 1.0
        distance = torch.norm(x - y, p=2, dim=-1)
        # Apply max to avoid zero gradients
        distance = torch.clamp(distance, min=eps)
        return torch.exp(-w_scale * distance / D)
    
    def forward(self, model_outputs, time_weights, stop_gradient_outputs):
        """
        Compute the IMM loss for a batch of model outputs.
        
        Args:
            model_outputs: Dictionary containing:
                - ys_t: outputs from time t to s [B, self.pred_horizon, self.obs_horizon]
                - ys_r: outputs from time r to s [B, self.pred_horizon, self.obs_horizon]
                - w_scale: time-dependent scaling factors [B]
            time_weights: w(s,t) weights [B/M]
            stop_gradient_outputs: Optional dictionary with same structure as model_outputs
                                   containing the detached outputs (θ-)
        """
        
        # Extract batch size and reshape for group processing
        batch_size = model_outputs['ys_t'].shape[0]
        M = self.num_particles
        num_groups = batch_size // M
        
        # Flatten pred_horizon and obs_horizon dimensions before reshaping
        # Reshape tensors to [num_groups, M, D]
        ys_t = model_outputs['ys_t'].reshape(batch_size, -1).reshape(num_groups, M, -1)
        ys_r_stop = stop_gradient_outputs['ys_r'].reshape(batch_size, -1).reshape(num_groups, M, -1)
        w_scale = model_outputs['w_scale'].reshape(num_groups, M)
        

        
        # Reshape time weights to [num_groups] by extracting the first element of each group
        time_weights = time_weights.reshape(num_groups, M)[:,0].reshape(-1)
        
        total_loss = 0.0
        for i in range(num_groups):
            group_loss = 0.0
            
            # Compute the kernel matrices
            for j in range(M):
                for k in range(M):
                    # First term: k(f_s,t^θ(x_t^(i,j)), f_s,t^θ(x_t^(i,k)))
                    term1 = self.laplace_kernel(ys_t[i, j], ys_t[i, k], w_scale[i, j])
                    
                    # Second term: k(f_s,r^θ-(x_r^(i,j)), f_s,r^θ-(x_r^(i,k)))
                    term2 = self.laplace_kernel(ys_r_stop[i, j], ys_r_stop[i, k], w_scale[i, j])
                    
                    # Third term: -2k(f_s,t^θ(x_t^(i,j)), f_s,r^θ-(x_r^(i,k)))
                    term3 = -2.0 * self.laplace_kernel(ys_t[i, j], ys_r_stop[i, k], w_scale[i, j])
                    
                    # Sum up the terms
                    group_loss += term1 + term2 + term3
            
            # Apply time-dependent weighting
            group_loss = group_loss * time_weights[i] / (M * M)
            total_loss += group_loss
        
        # Average over the number of groups
        return total_loss / num_groups

class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class Downsample1d(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv1d(dim, dim, 3, 2, 1)

    def forward(self, x):
        return self.conv(x)

class Upsample1d(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)

    def forward(self, x):
        return self.conv(x)


class Conv1dBlock(nn.Module):
    '''
        Conv1d --> GroupNorm --> Mish
    '''

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        self.block = nn.Sequential(
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
            nn.GroupNorm(n_groups, out_channels),
            nn.Mish(),
        )

    def forward(self, x):
        return self.block(x)


class ConditionalResidualBlock1D(nn.Module):
    def __init__(self,
            in_channels,
            out_channels,
            cond_dim,
            kernel_size=3,
            n_groups=8):
        super().__init__()

        self.blocks = nn.ModuleList([
            Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),
            Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),
        ])

        # FiLM modulation https://arxiv.org/abs/1709.07871
        # predicts per-channel scale and bias
        cond_channels = out_channels * 2
        self.out_channels = out_channels
        self.cond_encoder = nn.Sequential(
            nn.Mish(),
            nn.Linear(cond_dim, cond_channels),
            nn.Unflatten(-1, (-1, 1))
        )

        # make sure dimensions compatible
        self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \
            if in_channels != out_channels else nn.Identity()

    def forward(self, x, cond):
        '''
            x : [ batch_size x in_channels x horizon ]
            cond : [ batch_size x cond_dim]

            returns:
            out : [ batch_size x out_channels x horizon ]
        '''
        out = self.blocks[0](x)
        embed = self.cond_encoder(cond)

        embed = embed.reshape(
            embed.shape[0], 2, self.out_channels, 1)
        scale = embed[:,0,...]
        bias = embed[:,1,...]
        out = scale * out + bias

        out = self.blocks[1](out)
        out = out + self.residual_conv(x)
        return out


class ConditionalUnet1D(nn.Module):
    def __init__(self,
        input_dim,
        global_cond_dim,
        diffusion_step_embed_dim=256,
        down_dims=[256,512,1024],
        kernel_size=5,
        n_groups=8
        ):
        """
        input_dim: Dim of actions.
        global_cond_dim: Dim of global conditioning applied with FiLM
          in addition to diffusion step embedding. This is usually obs_horizon * obs_dim
        diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k
        down_dims: Channel size for each UNet level.
          The length of this array determines numebr of levels.
        kernel_size: Conv kernel size
        n_groups: Number of groups for GroupNorm
        """

        super().__init__()
        all_dims = [input_dim] + list(down_dims)
        start_dim = down_dims[0]

        dsed = diffusion_step_embed_dim
        diffusion_step_encoder = nn.Sequential(
            SinusoidalPosEmb(dsed),
            nn.Linear(dsed, dsed * 4),
            nn.Mish(),
            nn.Linear(dsed * 4, dsed),
        )
        
        # Second encoder for timestep s (for IMM)
        diffusion_step_encoder_s = nn.Sequential(
            SinusoidalPosEmb(dsed),
            nn.Linear(dsed, dsed * 4),
            nn.Mish(),
            nn.Linear(dsed * 4, dsed),
        )
            
        # Total conditioning dimensions: t embedding + s embedding + global conditioning
        cond_dim = dsed * 2 + global_cond_dim

        in_out = list(zip(all_dims[:-1], all_dims[1:]))
        mid_dim = all_dims[-1]
        self.mid_modules = nn.ModuleList([
            ConditionalResidualBlock1D(
                mid_dim, mid_dim, cond_dim=cond_dim,
                kernel_size=kernel_size, n_groups=n_groups
            ),
            ConditionalResidualBlock1D(
                mid_dim, mid_dim, cond_dim=cond_dim,
                kernel_size=kernel_size, n_groups=n_groups
            ),
        ])

        down_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (len(in_out) - 1)
            down_modules.append(nn.ModuleList([
                ConditionalResidualBlock1D(
                    dim_in, dim_out, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                ConditionalResidualBlock1D(
                    dim_out, dim_out, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                Downsample1d(dim_out) if not is_last else nn.Identity()
            ]))

        up_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (len(in_out) - 1)
            up_modules.append(nn.ModuleList([
                ConditionalResidualBlock1D(
                    dim_out*2, dim_in, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                ConditionalResidualBlock1D(
                    dim_in, dim_in, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                Upsample1d(dim_in) if not is_last else nn.Identity()
            ]))

        final_conv = nn.Sequential(
            Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
            nn.Conv1d(start_dim, input_dim, 1),
        )

        self.diffusion_step_encoder = diffusion_step_encoder
        self.diffusion_step_encoder_s = diffusion_step_encoder_s
        self.up_modules = up_modules
        self.down_modules = down_modules
        self.final_conv = final_conv

        print("number of parameters: {:e}".format(
            sum(p.numel() for p in self.parameters()))
        )

    def forward(self,
            sample: torch.Tensor,
            timestep: Union[torch.Tensor, float],
            timestep_s: Union[torch.Tensor, float],
            global_cond):
        """
        x: (B,T,input_dim)
        timestep: (B,) or int, diffusion step t
        timestep_s: (B,) or int, diffusion step s (for IMM)
        global_cond: (B,global_cond_dim)
        output: (B,T,input_dim)
        """
        # (B,T,C)
        sample = sample.moveaxis(-1,-2)
        # (B,C,T)

        # 1. time embedding for t
        # timesteps = timestep
        # if not torch.is_tensor(timesteps):
        #     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
        #     timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        # elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
        #     timesteps = timesteps[None].to(sample.device)
        # # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        # timesteps = timesteps.expand(sample.shape[0])
        
        # Get time embedding for t
        t_emb = self.diffusion_step_encoder(timestep)
        
        # 2. time embedding for s
        s_emb = self.diffusion_step_encoder_s(timestep_s)
        
        # Combine t and s embeddings
        global_feature = torch.cat([t_emb, s_emb], dim=-1)

        if global_cond is not None:
            global_feature = torch.cat([
                global_feature, global_cond
            ], axis=-1)

        x = sample
        h = []
        for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            h.append(x)
            x = downsample(x)

        for mid_module in self.mid_modules:
            x = mid_module(x, global_feature)

        for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            x = upsample(x)

        x = self.final_conv(x)

        # (B,C,T)
        x = x.moveaxis(-1,-2)
        # (B,T,C)
        return x

class RoboIMM:
    """
    Simplified class for IMM with the robotics UNet.
    """
    def __init__(
        self,
        model,
        sigma_data,
        obs_horizon,
        pred_horizon,
        num_particles
    ):
        """
        Initialize the IMM sampler.
        
        Args:
            model: The ConditionalUnet1D model
            sigma_data: Data standard deviation
        """
        self.model = model
        self.sigma_data = sigma_data
        self.obs_horizon = obs_horizon
        self.pred_horizon = pred_horizon
        self.num_particles = num_particles

        self.loss = IMMloss(obs_horizon=obs_horizon, pred_horizon=pred_horizon, num_particles=num_particles)

    def get_alpha_sigma(self, t):
        """Get alpha and sigma values for time t."""
        # Using the "flow matching" schedule
        alpha_t = (1 - t)
        sigma_t = t
        return alpha_t, sigma_t
    
    # def euler_step(self, yt, pred, t, s):
    #     """Euler step for flow matching."""
    #     return yt - (t - s) * self.sigma_data * pred
    
    # def edm_step(self, yt, pred, t, s):
    #     """EDM step for sampling."""
    #     alpha_t, sigma_t = self.get_alpha_sigma(t)
    #     alpha_s, sigma_s = self.get_alpha_sigma(s)
         
    #     c_skip = (alpha_t * alpha_s + sigma_t * sigma_s) / (alpha_t**2 + sigma_t**2)
    #     c_out = -(alpha_s * sigma_t - alpha_t * sigma_s) * (alpha_t**2 + sigma_t**2).rsqrt() * self.sigma_data
        
    #     return c_skip * yt + c_out * pred
    
    def ddim(self, yt, y, t, s, noise=None):
        alpha_t, sigma_t = self.get_alpha_sigma(t)
        alpha_s, sigma_s = self.get_alpha_sigma(s)

        alpha_s = alpha_s.reshape(-1,1,1)
        sigma_s = sigma_s.reshape(-1,1,1)
        alpha_t = alpha_t.reshape(-1,1,1)
        sigma_t = sigma_t.reshape(-1,1,1)
        
        if noise is None: 
            ys = (alpha_s -   alpha_t * sigma_s / sigma_t) * y + sigma_s / sigma_t * yt
        else:
            ys = alpha_s * y + sigma_s * noise
        return ys
    
    def sample(self, shape, steps=20, global_cond=None, sampling_method="ddim"):
        """
        Generate samples using IMM sampling.
        
        Args:
            shape: Shape of the samples to generate
            steps: Number of sampling steps
            global_cond: Global conditioning
            sampling_method: "ddim"
            
        Returns:
            Generated samples
        """
        device = next(self.model.parameters()).device
        
        # Initialize with noise
        x = torch.randn(shape, device=device) * self.sigma_data
        
        # Define time steps (uniform steps from 1 to 0)
        times = torch.linspace(0.994, 0.006, steps + 1, device=device)
        
        for i in range(steps):
            t = times[i]
            s = times[i + 1]
            
            # Create batched time tensors
            t_batch = torch.full((shape[0],), t, device=device)
            s_batch = torch.full((shape[0],), s, device=device)
            
            # Run model forward
            with torch.no_grad():
                pred = self.model(x, t_batch, s_batch, global_cond)
            
            # Apply sampling function based on method
            if sampling_method == "ddim":
                x = self.ddim(x, pred, t_batch.view(-1, 1, 1), s_batch.view(-1, 1, 1))
            else:
                raise ValueError(f"Unknown sampling method: {sampling_method}")
        
        return x

    def calculate_weights(self, s_times, t_times):
        """
        Calculate the time-dependent weighting function w(s,t)
        """
        b = 5  # Hyperparameter from paper
        a = 1    # Hyperparameter from paper (a ∈ {1, 2})

        alpha_t, sigma_t = self.get_alpha_sigma(t_times)
        
        # Calculate log-SNR values
        log_snr_t = 2 * torch.log(alpha_t / sigma_t)
        dlog_snr_t = 2 / (torch.square(t_times) - t_times)
        
        # Calculate coefficient based on equation 13
        sigmoid_term = torch.sigmoid(b - log_snr_t)
        
        snr_term = (alpha_t ** a) / (alpha_t ** 2 + sigma_t ** 2)
        
        return 0.5 * sigmoid_term * -1.0 * dlog_snr_t * snr_term

    def train(self, train_loader, val_loader, num_epochs=100, lr=1e-4, device='cuda'):
        """
        Train the model.
        """

        # Standard ADAM optimizer
        # Note that EMA parametesr are not optimized
        optimizer = torch.optim.AdamW(
            params=self.model.parameters(),
            lr=lr, weight_decay=0)

        # Cosine LR schedule
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs)

        self.model.train()
        
        with tqdm(range(num_epochs), desc='Epoch') as tglobal:
            # epoch loop
            for epoch_idx in tglobal:
                # batch loop
                with tqdm(train_loader, desc='Batch', leave=False) as tepoch:
                    for batch_idx, batch in enumerate(tepoch):
                        nobs = batch['obs'].to(device)
                        naction = batch['action'].to(device)
                        B = nobs.shape[0]

                        num_groups = B // self.num_particles
                        s_times = torch.rand(num_groups, device=device)
                        t_times = s_times + (1 - s_times) * torch.rand(num_groups, device=device)
                        r_times = s_times + (t_times - s_times) * torch.rand(num_groups, device=device)

                        # times need to be shape (B,), currently they are shape (num_groups,)
                        s_times = s_times.reshape(-1,1).expand(num_groups, self.num_particles).reshape(-1)
                        t_times = t_times.reshape(-1,1).expand(num_groups, self.num_particles).reshape(-1)
                        r_times = r_times.reshape(-1,1).expand(num_groups, self.num_particles).reshape(-1)

                        noise = torch.randn_like(naction) * self.sigma_data

                        x_t = self.ddim(yt=None, y=naction, t=t_times, s=s_times, noise=noise)
                        x_r = self.ddim(yt=x_t, y=naction, t=t_times, s=r_times)
                                                
                        # observation as FiLM conditioning
                        # (B, obs_horizon, obs_dim)
                        obs_cond = nobs[:,:self.obs_horizon,:]
                        # (B, obs_horizon * obs_dim)
                        obs_cond = obs_cond.flatten(start_dim=1)

                        optimizer.zero_grad()
                        pred_grad = self.model(x_t, t_times, s_times, obs_cond)

                        with torch.no_grad():
                            pred_nograd = self.model(x_r, r_times, s_times, obs_cond)
    
                        time_weights = self.calculate_weights(s_times, t_times)
                        
                        # Reshape predictions to match expected dimensions
                        # Assuming pred_grad and pred_nograd are [B, sequence_length, action_dim]
                        # Reshape to [B, self.pred_horizon, self.obs_horizon]
                        model_outputs = {
                            'ys_t': pred_grad.reshape(B, self.pred_horizon, self.obs_horizon),
                            'w_scale': 1.0 / torch.abs((t_times - s_times) * self.sigma_data)
                        }
                        
                        stop_gradient_outputs = {
                            'ys_r': pred_nograd.reshape(B, self.pred_horizon, self.obs_horizon).detach()
                        }
                        
                        loss = self.loss(model_outputs, time_weights, stop_gradient_outputs)
                        
                        loss.backward()
                        optimizer.step()
                        lr_scheduler.step()

                        if batch_idx % 10 == 0:
                            print(f"Epoch {epoch_idx}, Batch {batch_idx}, Loss: {loss.item()}")
                        if batch_idx % 100 == 0:
                            # save model checkpoint
                            torch.save({
                                'epoch': epoch_idx,
                                'model_state_dict': self.model.state_dict(),
                                'optimizer_state_dict': optimizer.state_dict(),
                                'loss': loss.item()
                            }, f"ckpts/model_checkpoint_{epoch_idx}_{batch_idx}_{loss.item()}.pth")
                                

In [21]:
#@markdown ### **Network Demo**

# observation and action dimensions corrsponding to
# the output of PushTEnv
obs_dim = 5
action_dim = 2

# create network object
noise_pred_net = ConditionalUnet1D(
    input_dim=action_dim,
    global_cond_dim=obs_dim*obs_horizon
)

imm = RoboIMM(
    model=noise_pred_net,
    sigma_data=stats['action_normalized']['std'].mean(),
    obs_horizon=obs_horizon,
    pred_horizon=pred_horizon,
    num_particles=4
)

# example inputs
noised_action = torch.randn((1, pred_horizon, action_dim))
obs = torch.zeros((1, obs_horizon, obs_dim))
diffusion_iter = torch.zeros((1,))
diffusion_iter_s = torch.zeros((1,))


r = imm.sample(shape=(1, pred_horizon, action_dim), steps=10, global_cond=obs.flatten(start_dim=1))

# # the noise prediction network
# # takes noisy action, diffusion iteration and observation as input
# # predicts the noise added to action
# noise = noise_pred_net(
#     sample=noised_action,
#     timestep=diffusion_iter,
#     timestep_s=diffusion_iter_s,
#     global_cond=obs.flatten(start_dim=1))

# # illustration of removing noise
# # the actual noise removal is performed by NoiseScheduler
# # and is dependent on the diffusion noise schedule
# denoised_action = noised_action - noise

# # for this demo, we use DDPMScheduler with 100 diffusion iterations
# num_diffusion_iters = 100
# noise_scheduler = DDPMScheduler(
#     num_train_timesteps=num_diffusion_iters,
#     # the choise of beta schedule has big impact on performance
#     # we found squared cosine works the best
#     beta_schedule='squaredcos_cap_v2',
#     # clip output to [-1,1] to improve stability
#     clip_sample=True,
#     # our network predicts noise (instead of denoised action)
#     prediction_type='epsilon'
# )

# device transfer
device = torch.device('cuda')
_ = noise_pred_net.to(device)

number of parameters: 6.954880e+07


In [22]:
imm.train(dataloader, None, num_epochs=100, lr=1e-4, device='cuda')

Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 0, Batch 0, Loss: 0.015672294422984123




Epoch 0, Batch 10, Loss: 0.0002584146859589964




Epoch 0, Batch 20, Loss: 0.00032096196082420647




Epoch 0, Batch 30, Loss: 0.0003530262620188296




Epoch 0, Batch 40, Loss: 1.944993891811464e-05




Epoch 0, Batch 50, Loss: 8.247630830737762e-06




Epoch 0, Batch 60, Loss: 1.1387604899937287e-05




Epoch 0, Batch 70, Loss: 1.3222301276982762e-05




Epoch 0, Batch 80, Loss: 1.6778782082838006e-05




Epoch 0, Batch 90, Loss: 1.6118536223075353e-05


Epoch:   1%|          | 1/100 [00:31<52:27, 31.80s/it]

Epoch 1, Batch 0, Loss: 1.305439855059376e-05




Epoch 1, Batch 10, Loss: 1.8929025827674195e-05




Epoch 1, Batch 20, Loss: 1.610410981811583e-05




Epoch 1, Batch 30, Loss: 1.7683232726994902e-05




Epoch 1, Batch 40, Loss: 2.0564155420288444e-05




Epoch 1, Batch 50, Loss: 2.378621320531238e-05




Epoch 1, Batch 60, Loss: 5.278332173475064e-05




Epoch 1, Batch 70, Loss: 0.0001771646784618497




Epoch 1, Batch 80, Loss: 0.00115089095197618




Epoch 1, Batch 90, Loss: 0.05917278304696083


Epoch:   2%|▏         | 2/100 [01:03<51:49, 31.73s/it]

Epoch 2, Batch 0, Loss: 0.11555440723896027




Epoch 2, Batch 10, Loss: 0.11423725634813309




Epoch 2, Batch 20, Loss: 0.049466364085674286




Epoch 2, Batch 30, Loss: 0.09434743225574493




Epoch 2, Batch 40, Loss: 0.04129397124052048




Epoch 2, Batch 50, Loss: 0.01443843450397253




Epoch 2, Batch 60, Loss: 0.02595186047255993




Epoch 2, Batch 70, Loss: 0.053040433675050735




Epoch 2, Batch 80, Loss: 0.15042200684547424




Epoch 2, Batch 90, Loss: 0.11046651005744934


Epoch:   3%|▎         | 3/100 [01:35<51:18, 31.74s/it]

Epoch 3, Batch 0, Loss: 0.13947370648384094




Epoch 3, Batch 10, Loss: 0.10718759894371033




Epoch 3, Batch 20, Loss: 0.12327928841114044




Epoch 3, Batch 30, Loss: 0.11040857434272766




Epoch 3, Batch 40, Loss: 0.12958687543869019




Epoch 3, Batch 50, Loss: 0.09032652527093887




Epoch 3, Batch 60, Loss: 0.06047194078564644




Epoch 3, Batch 70, Loss: 0.026710595935583115




Epoch 3, Batch 80, Loss: 0.07637452334165573




Epoch 3, Batch 90, Loss: 0.05412893369793892


Epoch:   4%|▍         | 4/100 [02:06<50:43, 31.70s/it]

Epoch 4, Batch 0, Loss: 0.0684567242860794




Epoch 4, Batch 10, Loss: 0.0908544734120369




Epoch 4, Batch 20, Loss: 0.06310322135686874




Epoch 4, Batch 30, Loss: 0.04795510694384575




Epoch 4, Batch 40, Loss: 0.07244791090488434




Epoch 4, Batch 50, Loss: 0.047366220504045486




Epoch 4, Batch 60, Loss: 0.06089244782924652




Epoch 4, Batch 70, Loss: 0.05493906885385513




Epoch 4, Batch 80, Loss: 0.04088616743683815




Epoch 4, Batch 90, Loss: 0.05075584352016449


Epoch:   5%|▌         | 5/100 [02:38<50:10, 31.69s/it]

Epoch 5, Batch 0, Loss: 0.046727586537599564




Epoch 5, Batch 10, Loss: 0.04521546512842178




Epoch 5, Batch 20, Loss: 0.03047412633895874




Epoch 5, Batch 30, Loss: 0.06514973193407059




Epoch 5, Batch 40, Loss: 0.034814946353435516




Epoch 5, Batch 50, Loss: 0.03143209591507912




Epoch 5, Batch 60, Loss: 0.0423675999045372




Epoch 5, Batch 70, Loss: 0.041005268692970276




Epoch 5, Batch 80, Loss: 0.02780047059059143




Epoch 5, Batch 90, Loss: 0.03847089409828186


Epoch:   6%|▌         | 6/100 [03:10<49:38, 31.68s/it]

Epoch 6, Batch 0, Loss: 0.042468488216400146




Epoch 6, Batch 10, Loss: 0.03582752123475075




Epoch 6, Batch 20, Loss: 0.04197497293353081




Epoch 6, Batch 30, Loss: 0.04372038692235947




Epoch 6, Batch 40, Loss: 0.030264176428318024




Epoch 6, Batch 50, Loss: 0.03972426801919937




Epoch 6, Batch 60, Loss: 0.03318924084305763




Epoch 6, Batch 70, Loss: 0.04083377122879028




Epoch 6, Batch 80, Loss: 0.031066540628671646




Epoch 6, Batch 90, Loss: 0.03563742712140083


Epoch:   7%|▋         | 7/100 [03:41<49:04, 31.66s/it]

Epoch 7, Batch 0, Loss: 0.0512005016207695




Epoch 7, Batch 10, Loss: 0.0513012632727623




Epoch 7, Batch 20, Loss: 0.020101284608244896




Epoch 7, Batch 30, Loss: 0.025754904374480247




Epoch 7, Batch 40, Loss: 0.03729740157723427




Epoch 7, Batch 50, Loss: 0.05579278618097305




Epoch 7, Batch 60, Loss: 0.03533906862139702




Epoch 7, Batch 70, Loss: 0.030943501740694046




Epoch 7, Batch 80, Loss: 0.023837991058826447




Epoch 7, Batch 90, Loss: 0.031722165644168854


Epoch:   8%|▊         | 8/100 [04:13<48:29, 31.63s/it]

Epoch 8, Batch 0, Loss: 0.059900663793087006




Epoch 8, Batch 10, Loss: 0.0232327189296484




Epoch 8, Batch 20, Loss: 0.03069567121565342




Epoch 8, Batch 30, Loss: 0.057828035205602646




Epoch 8, Batch 40, Loss: 0.06407923251390457




Epoch 8, Batch 50, Loss: 0.09599960595369339




Epoch 8, Batch 60, Loss: 0.09416241943836212




Epoch 8, Batch 70, Loss: 0.09208940714597702




Epoch 8, Batch 80, Loss: 0.08803592622280121




Epoch 8, Batch 90, Loss: 0.0771113783121109


Epoch:   9%|▉         | 9/100 [04:44<47:57, 31.62s/it]

Epoch 9, Batch 0, Loss: 0.07505542784929276




Epoch 9, Batch 10, Loss: 0.044002119451761246




Epoch 9, Batch 20, Loss: 0.06826872378587723




Epoch 9, Batch 30, Loss: 0.08140452951192856




Epoch 9, Batch 40, Loss: 0.05101313441991806




Epoch 9, Batch 50, Loss: 0.06702733039855957




Epoch 9, Batch 60, Loss: 0.0853688195347786




Epoch 9, Batch 70, Loss: 0.05276021361351013




Epoch 9, Batch 80, Loss: 0.05414295196533203




Epoch 9, Batch 90, Loss: 0.04993187263607979


Epoch:  10%|█         | 10/100 [05:16<47:28, 31.65s/it]

Epoch 10, Batch 0, Loss: 0.057552505284547806




Epoch 10, Batch 10, Loss: 0.08810373395681381




Epoch 10, Batch 20, Loss: 0.1442781239748001




Epoch 10, Batch 30, Loss: 0.0425603985786438




Epoch 10, Batch 40, Loss: 0.05190437659621239




Epoch 10, Batch 50, Loss: 0.08308536559343338




Epoch 10, Batch 60, Loss: 0.09771545976400375




Epoch 10, Batch 70, Loss: 0.07673957943916321




Epoch 10, Batch 80, Loss: 0.09290224313735962




Epoch 10, Batch 90, Loss: 0.051111962646245956


Epoch:  11%|█         | 11/100 [05:48<46:54, 31.62s/it]

Epoch 11, Batch 0, Loss: 0.05839889124035835




Epoch 11, Batch 10, Loss: 0.05998969450592995




Epoch 11, Batch 20, Loss: 0.08204451203346252




Epoch 11, Batch 30, Loss: 0.13396276533603668




Epoch 11, Batch 40, Loss: 0.11563524603843689




Epoch 11, Batch 50, Loss: 0.10997078567743301




Epoch 11, Batch 60, Loss: 0.12653648853302002




Epoch 11, Batch 70, Loss: 0.09544277936220169




Epoch 11, Batch 80, Loss: 0.10743221640586853




Epoch 11, Batch 90, Loss: 0.09001465141773224


Epoch:  12%|█▏        | 12/100 [06:20<46:29, 31.70s/it]

Epoch 12, Batch 0, Loss: 0.07381722331047058




Epoch 12, Batch 10, Loss: 0.06841756403446198




Epoch 12, Batch 20, Loss: 0.09680405259132385




Epoch 12, Batch 30, Loss: 0.08131814002990723




Epoch 12, Batch 40, Loss: 0.05859913304448128




Epoch 12, Batch 50, Loss: 0.09726405888795853




Epoch 12, Batch 60, Loss: 0.10120426118373871




Epoch 12, Batch 70, Loss: 0.10658863931894302




Epoch 12, Batch 80, Loss: 0.07400576770305634




Epoch 12, Batch 90, Loss: 0.059269774705171585


Epoch:  13%|█▎        | 13/100 [06:51<45:57, 31.70s/it]

Epoch 13, Batch 0, Loss: 0.05250972509384155




Epoch 13, Batch 10, Loss: 0.05958615988492966




Epoch 13, Batch 20, Loss: 0.05001519247889519




Epoch 13, Batch 30, Loss: 0.06755123287439346




Epoch 13, Batch 40, Loss: 0.08189750462770462




Epoch 13, Batch 50, Loss: 0.06094406545162201




Epoch 13, Batch 60, Loss: 0.0742131844162941




Epoch 13, Batch 70, Loss: 0.07772607356309891




Epoch 13, Batch 80, Loss: 0.07397135347127914




Epoch 13, Batch 90, Loss: 0.05438385531306267


Epoch:  14%|█▍        | 14/100 [07:23<45:24, 31.68s/it]

Epoch 14, Batch 0, Loss: 0.08582684397697449




Epoch 14, Batch 10, Loss: 0.07049419730901718




Epoch 14, Batch 20, Loss: 0.033660292625427246




Epoch 14, Batch 30, Loss: 0.07998335361480713




Epoch 14, Batch 40, Loss: 0.06476937234401703




Epoch 14, Batch 50, Loss: 0.05261191353201866




Epoch 14, Batch 60, Loss: 0.04843280836939812




Epoch 14, Batch 70, Loss: 0.030875492841005325




Epoch 14, Batch 80, Loss: 0.08273051679134369




Epoch 14, Batch 90, Loss: 0.05671673268079758


Epoch:  15%|█▌        | 15/100 [07:55<44:50, 31.66s/it]

Epoch 15, Batch 0, Loss: 0.05552753061056137




Epoch 15, Batch 10, Loss: 0.05643558129668236




Epoch 15, Batch 20, Loss: 0.05604921653866768




Epoch 15, Batch 30, Loss: 0.06775238364934921




Epoch 15, Batch 40, Loss: 0.06818769872188568




Epoch 15, Batch 50, Loss: 0.05619483068585396




Epoch 15, Batch 60, Loss: 0.03473120555281639




Epoch 15, Batch 70, Loss: 0.08440142124891281




Epoch 15, Batch 80, Loss: 0.0652615949511528




Epoch 15, Batch 90, Loss: 0.0697924867272377


Epoch:  16%|█▌        | 16/100 [08:26<44:17, 31.64s/it]

Epoch 16, Batch 0, Loss: 0.0700179785490036




Epoch 16, Batch 10, Loss: 0.038625724613666534




Epoch 16, Batch 20, Loss: 0.06328245252370834




Epoch 16, Batch 30, Loss: 0.07490521669387817




Epoch 16, Batch 40, Loss: 0.08990135043859482




Epoch 16, Batch 50, Loss: 0.049406278878450394




Epoch 16, Batch 60, Loss: 0.07374142110347748




Epoch 16, Batch 70, Loss: 0.06056509166955948




Epoch 16, Batch 80, Loss: 0.10356418043375015




Epoch 16, Batch 90, Loss: 0.09422662109136581


Epoch:  17%|█▋        | 17/100 [08:58<43:45, 31.64s/it]

Epoch 17, Batch 0, Loss: 0.08470176160335541




Epoch 17, Batch 10, Loss: 0.09507080912590027




Epoch 17, Batch 20, Loss: 0.05877156928181648




Epoch 17, Batch 30, Loss: 0.04557032138109207




Epoch 17, Batch 40, Loss: 0.05534149333834648




Epoch 17, Batch 50, Loss: 0.052199412137269974




Epoch 17, Batch 60, Loss: 0.06790190190076828




Epoch 17, Batch 70, Loss: 0.08549430221319199




Epoch 17, Batch 80, Loss: 0.04727267846465111




Epoch 17, Batch 90, Loss: 0.051959700882434845


Epoch:  18%|█▊        | 18/100 [09:31<43:42, 31.98s/it]

Epoch 18, Batch 0, Loss: 0.05381127819418907




Epoch 18, Batch 10, Loss: 0.046824827790260315




Epoch 18, Batch 20, Loss: 0.0869290679693222




Epoch 18, Batch 30, Loss: 0.0564202219247818




Epoch 18, Batch 40, Loss: 0.056556668132543564




Epoch 18, Batch 50, Loss: 0.04220356047153473




Epoch 18, Batch 60, Loss: 0.06543425470590591




Epoch 18, Batch 70, Loss: 0.03420913591980934




Epoch 18, Batch 80, Loss: 0.03090311773121357




Epoch 18, Batch 90, Loss: 0.04667620360851288


Epoch:  19%|█▉        | 19/100 [10:02<43:01, 31.87s/it]

Epoch 19, Batch 0, Loss: 0.030694300308823586




Epoch 19, Batch 10, Loss: 0.02704625204205513




Epoch 19, Batch 20, Loss: 0.033093422651290894




Epoch 19, Batch 30, Loss: 0.06169194355607033




Epoch 19, Batch 40, Loss: 0.05664736405014992




Epoch 19, Batch 50, Loss: 0.046082064509391785




Epoch 19, Batch 60, Loss: 0.07014502584934235




Epoch 19, Batch 70, Loss: 0.05182274430990219




Epoch 19, Batch 80, Loss: 0.053889818489551544




Epoch 19, Batch 90, Loss: 0.04912113770842552


Epoch:  20%|██        | 20/100 [10:34<42:23, 31.80s/it]

Epoch 20, Batch 0, Loss: 0.04921679198741913




Epoch 20, Batch 10, Loss: 0.0434122160077095




Epoch 20, Batch 20, Loss: 0.047627128660678864




Epoch 20, Batch 30, Loss: 0.041741810739040375




Epoch 20, Batch 40, Loss: 0.04429085925221443




Epoch 20, Batch 50, Loss: 0.03562973812222481




Epoch 20, Batch 60, Loss: 0.03683829680085182




Epoch 20, Batch 70, Loss: 0.029024425894021988




Epoch 20, Batch 80, Loss: 0.05152585357427597




Epoch 20, Batch 90, Loss: 0.05960836261510849


Epoch:  21%|██        | 21/100 [11:05<41:47, 31.75s/it]

Epoch 21, Batch 0, Loss: 0.08501456677913666




Epoch 21, Batch 10, Loss: 0.09815256297588348




Epoch 21, Batch 20, Loss: 0.08371341228485107




Epoch 21, Batch 30, Loss: 0.06353035569190979




Epoch 21, Batch 40, Loss: 0.06570456922054291




Epoch 21, Batch 50, Loss: 0.06277373433113098




Epoch 21, Batch 60, Loss: 0.03904159739613533




Epoch 21, Batch 70, Loss: 0.07852499932050705




Epoch 21, Batch 80, Loss: 0.04359494894742966




Epoch 21, Batch 90, Loss: 0.04410240799188614


Epoch:  22%|██▏       | 22/100 [11:37<41:10, 31.67s/it]

Epoch 22, Batch 0, Loss: 0.056425340473651886




Epoch 22, Batch 10, Loss: 0.05835435166954994




Epoch 22, Batch 20, Loss: 0.043081216514110565




Epoch 22, Batch 30, Loss: 0.06911081820726395




Epoch 22, Batch 40, Loss: 0.047907132655382156




Epoch 22, Batch 50, Loss: 0.034339290112257004




Epoch 22, Batch 60, Loss: 0.04245131462812424




Epoch 22, Batch 70, Loss: 0.03533429279923439




Epoch 22, Batch 80, Loss: 0.05981236696243286




Epoch 22, Batch 90, Loss: 0.0320173017680645


Epoch:  23%|██▎       | 23/100 [12:09<40:36, 31.64s/it]

Epoch 23, Batch 0, Loss: 0.038070034235715866




Epoch 23, Batch 10, Loss: 0.03465184569358826




Epoch 23, Batch 20, Loss: 0.03532801568508148




Epoch 23, Batch 30, Loss: 0.030759567394852638




Epoch 23, Batch 40, Loss: 0.028479890897870064




Epoch 23, Batch 50, Loss: 0.06064712628722191




Epoch 23, Batch 60, Loss: 0.0312652550637722




Epoch 23, Batch 70, Loss: 0.025924228131771088




Epoch 23, Batch 80, Loss: 0.032710347324609756




Epoch 23, Batch 90, Loss: 0.03079472854733467


Epoch:  24%|██▍       | 24/100 [12:40<40:05, 31.65s/it]

Epoch 24, Batch 0, Loss: 0.033249616622924805




Epoch 24, Batch 10, Loss: 0.03652862459421158




Epoch 24, Batch 20, Loss: 0.02498527243733406




Epoch 24, Batch 30, Loss: 0.03538942337036133




Epoch 24, Batch 40, Loss: 0.03671669214963913




Epoch 24, Batch 50, Loss: 0.038212090730667114




Epoch 24, Batch 60, Loss: 0.04708727449178696




Epoch 24, Batch 70, Loss: 0.04072742909193039




Epoch 24, Batch 80, Loss: 0.04705805331468582




Epoch 24, Batch 90, Loss: 0.03207019343972206


Epoch:  25%|██▌       | 25/100 [13:12<39:30, 31.61s/it]

Epoch 25, Batch 0, Loss: 0.021063992753624916




Epoch 25, Batch 10, Loss: 0.0592384859919548




Epoch 25, Batch 20, Loss: 0.03741519898176193




Epoch 25, Batch 30, Loss: 0.04819036275148392




Epoch 25, Batch 40, Loss: 0.05416630953550339




Epoch 25, Batch 50, Loss: 0.07205633074045181




Epoch 25, Batch 60, Loss: 0.06565690785646439




Epoch 25, Batch 70, Loss: 0.04073157161474228




Epoch 25, Batch 80, Loss: 0.043830111622810364




Epoch 25, Batch 90, Loss: 0.062471870332956314


Epoch:  26%|██▌       | 26/100 [13:43<38:58, 31.60s/it]

Epoch 26, Batch 0, Loss: 0.05372568964958191




Epoch 26, Batch 10, Loss: 0.03336005657911301




Epoch 26, Batch 20, Loss: 0.025378525257110596




Epoch 26, Batch 30, Loss: 0.03765444830060005




Epoch 26, Batch 40, Loss: 0.03631909564137459




Epoch 26, Batch 50, Loss: 0.05646519735455513




Epoch 26, Batch 60, Loss: 0.04138592630624771




Epoch 26, Batch 70, Loss: 0.05152040347456932




Epoch 26, Batch 80, Loss: 0.026290010660886765




Epoch 26, Batch 90, Loss: 0.048964858055114746


Epoch:  27%|██▋       | 27/100 [14:15<38:29, 31.63s/it]

Epoch 27, Batch 0, Loss: 0.020163901150226593




Epoch 27, Batch 10, Loss: 0.02291061356663704




Epoch 27, Batch 20, Loss: 0.03432069718837738




Epoch 27, Batch 30, Loss: 0.08974862098693848




Epoch 27, Batch 40, Loss: 0.0873563289642334




Epoch 27, Batch 50, Loss: 0.07626917958259583




Epoch 27, Batch 60, Loss: 0.050022758543491364




Epoch 27, Batch 70, Loss: 0.05668250471353531




Epoch 27, Batch 80, Loss: 0.06365892291069031




Epoch 27, Batch 90, Loss: 0.0551142655313015


Epoch:  28%|██▊       | 28/100 [14:47<37:55, 31.61s/it]

Epoch 28, Batch 0, Loss: 0.049944255501031876




Epoch 28, Batch 10, Loss: 0.045868657529354095




Epoch 28, Batch 20, Loss: 0.03960757330060005




Epoch 28, Batch 30, Loss: 0.04731199890375137




Epoch 28, Batch 40, Loss: 0.06682855635881424




Epoch 28, Batch 50, Loss: 0.043169498443603516




Epoch 28, Batch 60, Loss: 0.05568912252783775




Epoch 28, Batch 70, Loss: 0.03921816870570183




Epoch 28, Batch 80, Loss: 0.06088819354772568




Epoch 28, Batch 90, Loss: 0.06986235082149506


Epoch:  29%|██▉       | 29/100 [15:20<37:55, 32.05s/it]

Epoch 29, Batch 0, Loss: 0.089528888463974




Epoch 29, Batch 10, Loss: 0.04865358769893646




Epoch 29, Batch 20, Loss: 0.13097937405109406




Epoch 29, Batch 30, Loss: 0.03316792845726013




Epoch 29, Batch 40, Loss: 0.047171451151371




Epoch 29, Batch 50, Loss: 0.03468158841133118




Epoch 29, Batch 60, Loss: 0.030467456206679344




Epoch 29, Batch 70, Loss: 0.028373513370752335




Epoch 29, Batch 80, Loss: 0.039735496044158936




Epoch 29, Batch 90, Loss: 0.04294941574335098


Epoch:  30%|███       | 30/100 [15:51<37:17, 31.96s/it]

Epoch 30, Batch 0, Loss: 0.07558531314134598




Epoch 30, Batch 10, Loss: 0.08176468312740326




Epoch 30, Batch 20, Loss: 0.04136767238378525




Epoch 30, Batch 30, Loss: 0.04914656654000282




Epoch 30, Batch 40, Loss: 0.044868893921375275




Epoch 30, Batch 50, Loss: 0.061796121299266815




Epoch 30, Batch 60, Loss: 0.0445658378303051




Epoch 30, Batch 70, Loss: 0.04331878200173378




Epoch 30, Batch 80, Loss: 0.04805447906255722




Epoch 30, Batch 90, Loss: 0.032382793724536896


Epoch:  31%|███       | 31/100 [16:23<36:39, 31.88s/it]

Epoch 31, Batch 0, Loss: 0.04827161133289337




Epoch 31, Batch 10, Loss: 0.03852679952979088




Epoch 31, Batch 20, Loss: 0.04225360229611397




Epoch 31, Batch 30, Loss: 0.028655774891376495




Epoch 31, Batch 40, Loss: 0.04785682633519173




Epoch 31, Batch 50, Loss: 0.035656776279211044




Epoch 31, Batch 60, Loss: 0.0180327408015728




Epoch 31, Batch 70, Loss: 0.02393188513815403




Epoch 31, Batch 80, Loss: 0.0404076874256134




Epoch 31, Batch 90, Loss: 0.08554762601852417


Epoch:  32%|███▏      | 32/100 [16:55<36:06, 31.87s/it]

Epoch 32, Batch 0, Loss: 0.058687180280685425




Epoch 32, Batch 10, Loss: 0.0815172865986824




Epoch 32, Batch 20, Loss: 0.03398514539003372




Epoch 32, Batch 30, Loss: 0.04052111878991127




Epoch 32, Batch 40, Loss: 0.030858825892210007




Epoch 32, Batch 50, Loss: 0.0319548137485981




Epoch 32, Batch 60, Loss: 0.047791771590709686




Epoch 32, Batch 70, Loss: 0.043496355414390564




Epoch 32, Batch 80, Loss: 0.025467202067375183




Epoch 32, Batch 90, Loss: 0.051706403493881226


Epoch:  33%|███▎      | 33/100 [17:27<35:30, 31.79s/it]

Epoch 33, Batch 0, Loss: 0.05398767068982124




Epoch 33, Batch 10, Loss: 0.03485621139407158




Epoch 33, Batch 20, Loss: 0.03781149163842201




Epoch 33, Batch 30, Loss: 0.03206922858953476




Epoch 33, Batch 40, Loss: 0.03356743976473808




Epoch 33, Batch 50, Loss: 0.027132995426654816




Epoch 33, Batch 60, Loss: 0.03883069381117821




Epoch 33, Batch 70, Loss: 0.029946306720376015




Epoch 33, Batch 80, Loss: 0.026174599304795265




Epoch 33, Batch 90, Loss: 0.06603265553712845


Epoch:  34%|███▍      | 34/100 [17:58<34:55, 31.75s/it]

Epoch 34, Batch 0, Loss: 0.045337893068790436




Epoch 34, Batch 10, Loss: 0.02726653590798378




Epoch 34, Batch 20, Loss: 0.027235746383666992




Epoch 34, Batch 30, Loss: 0.021574227139353752




Epoch 34, Batch 40, Loss: 0.045146360993385315




Epoch 34, Batch 50, Loss: 0.03270601108670235




Epoch 34, Batch 60, Loss: 0.027256641536951065




Epoch 34, Batch 70, Loss: 0.0298237893730402




Epoch 34, Batch 80, Loss: 0.03203393891453743




Epoch 34, Batch 90, Loss: 0.027838626876473427


Epoch:  35%|███▌      | 35/100 [18:30<34:26, 31.79s/it]

Epoch 35, Batch 0, Loss: 0.027684293687343597




Epoch 35, Batch 10, Loss: 0.028410617262125015




Epoch 35, Batch 20, Loss: 0.04910638928413391




Epoch 35, Batch 30, Loss: 0.02835247665643692




Epoch 35, Batch 40, Loss: 0.023059193044900894




Epoch 35, Batch 50, Loss: 0.04623445123434067




Epoch 35, Batch 60, Loss: 0.029169630259275436




Epoch 35, Batch 70, Loss: 0.05296581611037254




Epoch 35, Batch 80, Loss: 0.032851945608854294




Epoch 35, Batch 90, Loss: 0.02652275189757347


Epoch:  36%|███▌      | 36/100 [19:02<33:55, 31.80s/it]

Epoch 36, Batch 0, Loss: 0.022844621911644936




Epoch 36, Batch 10, Loss: 0.035322654992341995




Epoch 36, Batch 20, Loss: 0.019052637740969658




Epoch 36, Batch 30, Loss: 0.0454099215567112




Epoch 36, Batch 40, Loss: 0.041013821959495544




Epoch 36, Batch 50, Loss: 0.03861114755272865




Epoch 36, Batch 60, Loss: 0.019723905250430107




Epoch 36, Batch 70, Loss: 0.036212339997291565




Epoch 36, Batch 80, Loss: 0.03720578923821449




Epoch 36, Batch 90, Loss: 0.03468235954642296


Epoch:  37%|███▋      | 37/100 [19:34<33:26, 31.85s/it]

Epoch 37, Batch 0, Loss: 0.03637182340025902




Epoch 37, Batch 10, Loss: 0.03853893280029297




Epoch 37, Batch 20, Loss: 0.03860520198941231




Epoch 37, Batch 30, Loss: 0.023121559992432594




Epoch 37, Batch 40, Loss: 0.01681201159954071




Epoch 37, Batch 50, Loss: 0.02605554275214672




Epoch 37, Batch 60, Loss: 0.04227181896567345




Epoch 37, Batch 70, Loss: 0.0631844699382782




Epoch 37, Batch 80, Loss: 0.04879361763596535




Epoch 37, Batch 90, Loss: 0.038135234266519547


Epoch:  38%|███▊      | 38/100 [20:06<32:54, 31.84s/it]

Epoch 38, Batch 0, Loss: 0.05872931331396103




Epoch 38, Batch 10, Loss: 0.031226633116602898




Epoch 38, Batch 20, Loss: 0.05522364377975464




Epoch 38, Batch 30, Loss: 0.04931415989995003




Epoch 38, Batch 40, Loss: 0.06129459664225578




Epoch 38, Batch 50, Loss: 0.06403027474880219




Epoch 38, Batch 60, Loss: 0.05538936331868172




Epoch 38, Batch 70, Loss: 0.04292511194944382




Epoch 38, Batch 80, Loss: 0.055188585072755814




Epoch 38, Batch 90, Loss: 0.07257276028394699


Epoch:  39%|███▉      | 39/100 [20:38<32:25, 31.90s/it]

Epoch 39, Batch 0, Loss: 0.07386764883995056




Epoch 39, Batch 10, Loss: 0.05812130868434906




Epoch 39, Batch 20, Loss: 0.051219966262578964




Epoch 39, Batch 30, Loss: 0.06431757658720016




Epoch 39, Batch 40, Loss: 0.05708467960357666




Epoch 39, Batch 50, Loss: 0.05299711599946022




Epoch 39, Batch 60, Loss: 0.043503034859895706




Epoch 39, Batch 70, Loss: 0.033651307225227356




Epoch 39, Batch 80, Loss: 0.048702139407396317




Epoch 39, Batch 90, Loss: 0.030902232974767685


Epoch:  40%|████      | 40/100 [21:10<31:57, 31.96s/it]

Epoch 40, Batch 0, Loss: 0.0448773019015789




Epoch 40, Batch 10, Loss: 0.055939096957445145




Epoch 40, Batch 20, Loss: 0.03501896560192108




Epoch 40, Batch 30, Loss: 0.03753204271197319




Epoch 40, Batch 40, Loss: 0.036198049783706665




Epoch 40, Batch 50, Loss: 0.04733648523688316




Epoch 40, Batch 60, Loss: 0.021604904904961586




Epoch 40, Batch 70, Loss: 0.019354727119207382




Epoch 40, Batch 80, Loss: 0.028854666277766228




Epoch 40, Batch 90, Loss: 0.03251180797815323


Epoch:  41%|████      | 41/100 [21:42<31:25, 31.95s/it]

Epoch 41, Batch 0, Loss: 0.036146897822618484




Epoch 41, Batch 10, Loss: 0.0385451503098011




Epoch 41, Batch 20, Loss: 0.045885153114795685




Epoch 41, Batch 30, Loss: 0.03926130756735802




Epoch 41, Batch 40, Loss: 0.039078518748283386




Epoch 41, Batch 50, Loss: 0.018699288368225098




Epoch 41, Batch 60, Loss: 0.020272938534617424




Epoch 41, Batch 70, Loss: 0.0296490415930748




Epoch 41, Batch 80, Loss: 0.0328153520822525




Epoch 41, Batch 90, Loss: 0.03684351220726967


Epoch:  42%|████▏     | 42/100 [22:14<30:53, 31.95s/it]

Epoch 42, Batch 0, Loss: 0.03198004886507988




Epoch 42, Batch 10, Loss: 0.05827144905924797




Epoch 42, Batch 20, Loss: 0.05094094201922417




Epoch 42, Batch 30, Loss: 0.051560238003730774




Epoch 42, Batch 40, Loss: 0.03661981225013733




Epoch 42, Batch 50, Loss: 0.07256348431110382




Epoch 42, Batch 60, Loss: 0.053420037031173706




Epoch 42, Batch 70, Loss: 0.02736598812043667




Epoch 42, Batch 80, Loss: 0.03037046268582344




Epoch 42, Batch 90, Loss: 0.04001978412270546


Epoch:  43%|████▎     | 43/100 [22:46<30:20, 31.94s/it]

Epoch 43, Batch 0, Loss: 0.022396283224225044




Epoch 43, Batch 10, Loss: 0.03136849030852318




Epoch 43, Batch 20, Loss: 0.03178466111421585




Epoch 43, Batch 30, Loss: 0.02725083753466606




Epoch 43, Batch 40, Loss: 0.02710346132516861




Epoch 43, Batch 50, Loss: 0.03967530280351639




Epoch 43, Batch 60, Loss: 0.02376527152955532




Epoch 43, Batch 70, Loss: 0.01897444948554039




Epoch 43, Batch 80, Loss: 0.06059064343571663




Epoch 43, Batch 90, Loss: 0.04211999475955963


Epoch:  44%|████▍     | 44/100 [23:17<29:43, 31.84s/it]

Epoch 44, Batch 0, Loss: 0.051598913967609406




Epoch 44, Batch 10, Loss: 0.043157171458005905




Epoch 44, Batch 20, Loss: 0.06541644781827927




Epoch 44, Batch 30, Loss: 0.06598937511444092




Epoch 44, Batch 40, Loss: 0.05845070257782936




Epoch 44, Batch 50, Loss: 0.0530817024409771




Epoch 44, Batch 60, Loss: 0.038725998252630234




Epoch 44, Batch 70, Loss: 0.032821863889694214




Epoch 44, Batch 80, Loss: 0.04924089461565018




Epoch 44, Batch 90, Loss: 0.03800319880247116


Epoch:  45%|████▌     | 45/100 [23:49<29:09, 31.81s/it]

Epoch 45, Batch 0, Loss: 0.0508289597928524




Epoch 45, Batch 10, Loss: 0.026646260172128677




Epoch 45, Batch 20, Loss: 0.022455591708421707




Epoch 45, Batch 30, Loss: 0.02828943356871605




Epoch 45, Batch 40, Loss: 0.043264806270599365




Epoch 45, Batch 50, Loss: 0.03566209599375725




Epoch 45, Batch 60, Loss: 0.02665109559893608




Epoch 45, Batch 70, Loss: 0.03122846595942974




Epoch 45, Batch 80, Loss: 0.06281714141368866




Epoch 45, Batch 90, Loss: 0.033040180802345276


Epoch:  46%|████▌     | 46/100 [24:21<28:36, 31.78s/it]

Epoch 46, Batch 0, Loss: 0.02869221568107605




Epoch 46, Batch 10, Loss: 0.03702940419316292




Epoch 46, Batch 20, Loss: 0.05038704723119736




Epoch 46, Batch 30, Loss: 0.040483418852090836




Epoch 46, Batch 40, Loss: 0.02131926827132702




Epoch 46, Batch 50, Loss: 0.04445919021964073




Epoch 46, Batch 60, Loss: 0.040538471192121506




Epoch 46, Batch 70, Loss: 0.036049775779247284




Epoch 46, Batch 80, Loss: 0.04200649634003639




Epoch 46, Batch 90, Loss: 0.03620850294828415


Epoch:  47%|████▋     | 47/100 [24:52<28:04, 31.79s/it]

Epoch 47, Batch 0, Loss: 0.030415980145335197




Epoch 47, Batch 10, Loss: 0.0334952175617218




Epoch 47, Batch 20, Loss: 0.03393774479627609




Epoch 47, Batch 30, Loss: 0.03598718345165253




Epoch 47, Batch 40, Loss: 0.038031794130802155




Epoch 47, Batch 50, Loss: 0.046542905271053314




Epoch 47, Batch 60, Loss: 0.03176078945398331




Epoch 47, Batch 70, Loss: 0.035128623247146606




Epoch 47, Batch 80, Loss: 0.0293742623180151




Epoch 47, Batch 90, Loss: 0.05334609001874924


Epoch:  48%|████▊     | 48/100 [25:24<27:32, 31.78s/it]

Epoch 48, Batch 0, Loss: 0.03904075548052788




Epoch 48, Batch 10, Loss: 0.037796154618263245




Epoch 48, Batch 20, Loss: 0.03090851567685604




Epoch 48, Batch 30, Loss: 0.026473455131053925




Epoch 48, Batch 40, Loss: 0.037214286625385284




Epoch 48, Batch 50, Loss: 0.043199844658374786




Epoch 48, Batch 60, Loss: 0.03549010679125786




Epoch 48, Batch 70, Loss: 0.05648975446820259




Epoch 48, Batch 80, Loss: 0.04071149602532387




Epoch 48, Batch 90, Loss: 0.032766956835985184


Epoch:  49%|████▉     | 49/100 [25:56<26:58, 31.74s/it]

Epoch 49, Batch 0, Loss: 0.04901941865682602




Epoch 49, Batch 10, Loss: 0.03834553435444832




Epoch 49, Batch 20, Loss: 0.03892950341105461




Epoch 49, Batch 30, Loss: 0.03740430623292923




Epoch 49, Batch 40, Loss: 0.03656252846121788




Epoch 49, Batch 50, Loss: 0.03739872947335243




Epoch 49, Batch 60, Loss: 0.03333720564842224




Epoch 49, Batch 70, Loss: 0.05200624093413353




Epoch 49, Batch 80, Loss: 0.02276599407196045




Epoch 49, Batch 90, Loss: 0.04360572248697281


Epoch:  50%|█████     | 50/100 [26:27<26:25, 31.71s/it]

Epoch 50, Batch 0, Loss: 0.04012254625558853




Epoch 50, Batch 10, Loss: 0.033020853996276855




Epoch 50, Batch 20, Loss: 0.029576368629932404




Epoch 50, Batch 30, Loss: 0.019999071955680847




Epoch 50, Batch 40, Loss: 0.010731209069490433




Epoch 50, Batch 50, Loss: 0.031711723655462265




Epoch 50, Batch 60, Loss: 0.047296129167079926




Epoch 50, Batch 70, Loss: 0.021782787516713142




Epoch 50, Batch 80, Loss: 0.039419401437044144




Epoch 50, Batch 90, Loss: 0.022944310680031776


Epoch:  51%|█████     | 51/100 [26:59<25:55, 31.74s/it]

Epoch 51, Batch 0, Loss: 0.014009171165525913




Epoch 51, Batch 10, Loss: 0.04062747210264206




Epoch 51, Batch 20, Loss: 0.029482532292604446




Epoch 51, Batch 30, Loss: 0.03666293993592262




Epoch 51, Batch 40, Loss: 0.027956033125519753




Epoch 51, Batch 50, Loss: 0.04770239070057869




Epoch 51, Batch 60, Loss: 0.035890378057956696




Epoch 51, Batch 70, Loss: 0.03357192128896713




Epoch 51, Batch 80, Loss: 0.03494755178689957




Epoch 51, Batch 90, Loss: 0.04919365048408508


Epoch:  52%|█████▏    | 52/100 [27:31<25:22, 31.72s/it]

Epoch 52, Batch 0, Loss: 0.04313397407531738




Epoch 52, Batch 10, Loss: 0.043367695063352585




Epoch 52, Batch 20, Loss: 0.023089148104190826




Epoch 52, Batch 30, Loss: 0.024899670854210854




Epoch 52, Batch 40, Loss: 0.03796384856104851




Epoch 52, Batch 50, Loss: 0.022096658125519753




Epoch 52, Batch 60, Loss: 0.017786594107747078




Epoch 52, Batch 70, Loss: 0.020270362496376038




Epoch 52, Batch 80, Loss: 0.026149485260248184




Epoch 52, Batch 90, Loss: 0.020791878923773766


Epoch:  53%|█████▎    | 53/100 [28:03<24:51, 31.74s/it]

Epoch 53, Batch 0, Loss: 0.031041881069540977




Epoch 53, Batch 10, Loss: 0.062397126108407974




Epoch 53, Batch 20, Loss: 0.04072309657931328




Epoch 53, Batch 30, Loss: 0.034287769347429276




Epoch 53, Batch 40, Loss: 0.04429177567362785




Epoch 53, Batch 50, Loss: 0.029352473095059395




Epoch 53, Batch 60, Loss: 0.028224626556038857




Epoch 53, Batch 70, Loss: 0.02794588729739189




Epoch 53, Batch 80, Loss: 0.027945099398493767




Epoch 53, Batch 90, Loss: 0.019563229754567146


Epoch:  54%|█████▍    | 54/100 [28:35<24:20, 31.74s/it]

Epoch 54, Batch 0, Loss: 0.02316942997276783




Epoch 54, Batch 10, Loss: 0.034254785627126694




Epoch 54, Batch 20, Loss: 0.03255289047956467




Epoch 54, Batch 30, Loss: 0.04197417572140694




Epoch 54, Batch 40, Loss: 0.041178807616233826




Epoch 54, Batch 50, Loss: 0.02622876688838005




Epoch 54, Batch 60, Loss: 0.0431692898273468




Epoch 54, Batch 70, Loss: 0.030642399564385414




Epoch 54, Batch 80, Loss: 0.02568206936120987




Epoch 54, Batch 90, Loss: 0.03257337585091591


Epoch:  55%|█████▌    | 55/100 [29:06<23:48, 31.74s/it]

Epoch 55, Batch 0, Loss: 0.018118280917406082




Epoch 55, Batch 10, Loss: 0.019270144402980804




Epoch 55, Batch 20, Loss: 0.03713497519493103




Epoch 55, Batch 30, Loss: 0.02864423580467701




Epoch 55, Batch 40, Loss: 0.03721364587545395




Epoch 55, Batch 50, Loss: 0.02097330614924431




Epoch 55, Batch 60, Loss: 0.03514274209737778




Epoch 55, Batch 70, Loss: 0.036320582032203674




Epoch 55, Batch 80, Loss: 0.026257218793034554




Epoch 55, Batch 90, Loss: 0.03280285373330116


Epoch:  56%|█████▌    | 56/100 [29:38<23:15, 31.71s/it]

Epoch 56, Batch 0, Loss: 0.02551681362092495




Epoch 56, Batch 10, Loss: 0.04943891242146492




Epoch 56, Batch 20, Loss: 0.03491489216685295




Epoch 56, Batch 30, Loss: 0.020574569702148438




Epoch 56, Batch 40, Loss: 0.014209755696356297




Epoch 56, Batch 50, Loss: 0.02767040766775608




Epoch 56, Batch 60, Loss: 0.04063166677951813




Epoch 56, Batch 70, Loss: 0.023203624412417412




Epoch 56, Batch 80, Loss: 0.033428795635700226




Epoch 56, Batch 90, Loss: 0.021487146615982056


Epoch:  57%|█████▋    | 57/100 [30:10<22:47, 31.79s/it]

Epoch 57, Batch 0, Loss: 0.027930665761232376




Epoch 57, Batch 10, Loss: 0.03503219410777092




Epoch 57, Batch 20, Loss: 0.026699058711528778




Epoch 57, Batch 30, Loss: 0.0378364659845829




Epoch 57, Batch 40, Loss: 0.03678975626826286




Epoch 57, Batch 50, Loss: 0.034091077744960785




Epoch 57, Batch 60, Loss: 0.031342215836048126




Epoch 57, Batch 70, Loss: 0.022445157170295715




Epoch 57, Batch 80, Loss: 0.017638282850384712




Epoch 57, Batch 90, Loss: 0.022038890048861504


Epoch:  58%|█████▊    | 58/100 [30:42<22:13, 31.75s/it]

Epoch 58, Batch 0, Loss: 0.04022441804409027




Epoch 58, Batch 10, Loss: 0.02053273655474186




Epoch 58, Batch 20, Loss: 0.016694311052560806




Epoch 58, Batch 30, Loss: 0.015047136694192886




Epoch 58, Batch 40, Loss: 0.019577056169509888




Epoch 58, Batch 50, Loss: 0.051992468535900116




Epoch 58, Batch 60, Loss: 0.025764402002096176




Epoch 58, Batch 70, Loss: 0.030162163078784943




Epoch 58, Batch 80, Loss: 0.024228306487202644




Epoch 58, Batch 90, Loss: 0.027932999655604362


Epoch:  59%|█████▉    | 59/100 [31:13<21:40, 31.73s/it]

Epoch 59, Batch 0, Loss: 0.028637737035751343




Epoch 59, Batch 10, Loss: 0.03589056432247162




Epoch 59, Batch 20, Loss: 0.014321699738502502




Epoch 59, Batch 30, Loss: 0.026979461312294006




Epoch 59, Batch 40, Loss: 0.03378481790423393




Epoch 59, Batch 50, Loss: 0.031452812254428864




Epoch 59, Batch 60, Loss: 0.03514912351965904




Epoch 59, Batch 70, Loss: 0.021287450566887856




Epoch 59, Batch 80, Loss: 0.046266067773103714




Epoch 59, Batch 90, Loss: 0.04731845110654831


Epoch:  60%|██████    | 60/100 [31:45<21:06, 31.65s/it]

Epoch 60, Batch 0, Loss: 0.03197160363197327




Epoch 60, Batch 10, Loss: 0.041549794375896454




Epoch 60, Batch 20, Loss: 0.041447438299655914




Epoch 60, Batch 30, Loss: 0.043044015765190125




Epoch 60, Batch 40, Loss: 0.051960814744234085




Epoch 60, Batch 50, Loss: 0.03125961869955063




Epoch 60, Batch 60, Loss: 0.030454250052571297




Epoch 60, Batch 70, Loss: 0.045959848910570145




Epoch 60, Batch 80, Loss: 0.046630166471004486




Epoch 60, Batch 90, Loss: 0.06194256246089935


Epoch:  61%|██████    | 61/100 [32:16<20:33, 31.63s/it]

Epoch 61, Batch 0, Loss: 0.10338909924030304




Epoch 61, Batch 10, Loss: 0.0782690942287445




Epoch 61, Batch 20, Loss: 0.09788849204778671




Epoch 61, Batch 30, Loss: 0.0462954118847847




Epoch 61, Batch 40, Loss: 0.07253822684288025




Epoch 61, Batch 50, Loss: 0.0754106342792511




Epoch 61, Batch 60, Loss: 0.05618438497185707




Epoch 61, Batch 70, Loss: 0.060542602092027664




Epoch 61, Batch 80, Loss: 0.0752580538392067




Epoch 61, Batch 90, Loss: 0.054635100066661835


Epoch:  62%|██████▏   | 62/100 [32:48<20:02, 31.64s/it]

Epoch 62, Batch 0, Loss: 0.07605376094579697




Epoch 62, Batch 10, Loss: 0.0671108067035675




Epoch 62, Batch 20, Loss: 0.06595602631568909




Epoch 62, Batch 30, Loss: 0.08168062567710876




Epoch 62, Batch 40, Loss: 0.05591997504234314




Epoch 62, Batch 50, Loss: 0.07800890505313873




Epoch 62, Batch 60, Loss: 0.074987031519413




Epoch 62, Batch 70, Loss: 0.08481011539697647




Epoch 62, Batch 80, Loss: 0.06790643185377121




Epoch 62, Batch 90, Loss: 0.05564321577548981


Epoch:  63%|██████▎   | 63/100 [33:19<19:29, 31.60s/it]

Epoch 63, Batch 0, Loss: 0.06113773211836815




Epoch 63, Batch 10, Loss: 0.05445011332631111




Epoch 63, Batch 20, Loss: 0.07007608562707901




Epoch 63, Batch 30, Loss: 0.06520332396030426




Epoch 63, Batch 40, Loss: 0.10684303939342499




Epoch 63, Batch 50, Loss: 0.0539180263876915




Epoch 63, Batch 60, Loss: 0.06028998643159866




Epoch 63, Batch 70, Loss: 0.055061206221580505




Epoch 63, Batch 80, Loss: 0.06570129841566086




Epoch 63, Batch 90, Loss: 0.061037544161081314


Epoch:  64%|██████▍   | 64/100 [33:51<18:57, 31.59s/it]

Epoch 64, Batch 0, Loss: 0.07470227777957916




Epoch 64, Batch 10, Loss: 0.04895596578717232




Epoch 64, Batch 20, Loss: 0.07744082063436508




Epoch 64, Batch 30, Loss: 0.08162041008472443




Epoch 64, Batch 40, Loss: 0.03955091908574104




Epoch 64, Batch 50, Loss: 0.03920969367027283




Epoch 64, Batch 60, Loss: 0.07329770177602768




Epoch 64, Batch 70, Loss: 0.08501055836677551




Epoch 64, Batch 80, Loss: 0.06425658613443375




Epoch 64, Batch 90, Loss: 0.060678016394376755


Epoch:  65%|██████▌   | 65/100 [34:23<18:26, 31.60s/it]

Epoch 65, Batch 0, Loss: 0.060058582574129105




Epoch 65, Batch 10, Loss: 0.041912443935871124




Epoch 65, Batch 20, Loss: 0.08344577252864838




Epoch 65, Batch 30, Loss: 0.059047017246484756




Epoch 65, Batch 40, Loss: 0.015036888420581818




Epoch 65, Batch 50, Loss: 0.013424184173345566




Epoch 65, Batch 60, Loss: 0.01776011846959591




Epoch 65, Batch 70, Loss: 0.016725318506360054




Epoch 65, Batch 80, Loss: 0.025776587426662445




Epoch 65, Batch 90, Loss: 0.019604751840233803


Epoch:  66%|██████▌   | 66/100 [34:54<17:54, 31.60s/it]

Epoch 66, Batch 0, Loss: 0.022882742807269096




Epoch 66, Batch 10, Loss: 0.02621527947485447




Epoch 66, Batch 20, Loss: 0.029387440532445908




Epoch 66, Batch 30, Loss: 0.024519508704543114




Epoch 66, Batch 40, Loss: 0.023042907938361168




Epoch 66, Batch 50, Loss: 0.021360013633966446




Epoch 66, Batch 60, Loss: 0.020419592037796974




Epoch 66, Batch 70, Loss: 0.028725994750857353




Epoch 66, Batch 80, Loss: 0.026742208749055862




Epoch 66, Batch 90, Loss: 0.02774941921234131


Epoch:  67%|██████▋   | 67/100 [35:26<17:23, 31.61s/it]

Epoch 67, Batch 0, Loss: 0.02213282696902752




Epoch 67, Batch 10, Loss: 0.04088105633854866




Epoch 67, Batch 20, Loss: 0.04540512338280678




Epoch 67, Batch 30, Loss: 0.05836428329348564




Epoch 67, Batch 40, Loss: 0.05416211113333702




Epoch 67, Batch 50, Loss: 0.056858714669942856




Epoch 67, Batch 60, Loss: 0.031190527603030205




Epoch 67, Batch 70, Loss: 0.05337414890527725




Epoch 67, Batch 80, Loss: 0.1669027954339981




Epoch 67, Batch 90, Loss: 0.1022319495677948


Epoch:  68%|██████▊   | 68/100 [35:57<16:51, 31.61s/it]

Epoch 68, Batch 0, Loss: 0.13145752251148224




Epoch 68, Batch 10, Loss: 0.0981396958231926




Epoch 68, Batch 20, Loss: 0.07079947739839554




Epoch 68, Batch 30, Loss: 0.08688502758741379




Epoch 68, Batch 40, Loss: 0.07613057643175125




Epoch 68, Batch 50, Loss: 0.07108421623706818




Epoch 68, Batch 60, Loss: 0.07023828476667404




Epoch 68, Batch 70, Loss: 0.08937575668096542




Epoch 68, Batch 80, Loss: 0.1243695393204689




Epoch 68, Batch 90, Loss: 0.09542956948280334


Epoch:  69%|██████▉   | 69/100 [36:29<16:20, 31.61s/it]

Epoch 69, Batch 0, Loss: 0.12714259326457977




Epoch 69, Batch 10, Loss: 0.10181627422571182




Epoch 69, Batch 20, Loss: 0.07023933529853821




Epoch 69, Batch 30, Loss: 0.06094785034656525




Epoch 69, Batch 40, Loss: 0.08405537903308868




Epoch 69, Batch 50, Loss: 0.05483858287334442




Epoch 69, Batch 60, Loss: 0.056824781000614166




Epoch 69, Batch 70, Loss: 0.05277523025870323




Epoch 69, Batch 80, Loss: 0.04520978778600693




Epoch 69, Batch 90, Loss: 0.05262806639075279


Epoch:  70%|███████   | 70/100 [37:01<15:49, 31.64s/it]

Epoch 70, Batch 0, Loss: 0.04180648550391197




Epoch 70, Batch 10, Loss: 0.06240787357091904




Epoch 70, Batch 20, Loss: 0.061217788606882095




Epoch 70, Batch 30, Loss: 0.06146778538823128




Epoch 70, Batch 40, Loss: 0.07378510385751724




Epoch 70, Batch 50, Loss: 0.03735673055052757




Epoch 70, Batch 60, Loss: 0.06274641305208206




Epoch 70, Batch 70, Loss: 0.04430637136101723




Epoch 70, Batch 80, Loss: 0.03548701852560043




Epoch 70, Batch 90, Loss: 0.033840540796518326


Epoch:  71%|███████   | 71/100 [37:33<15:20, 31.74s/it]

Epoch 71, Batch 0, Loss: 0.045878488570451736




Epoch 71, Batch 10, Loss: 0.0539209209382534




Epoch 71, Batch 20, Loss: 0.03777129575610161




Epoch 71, Batch 30, Loss: 0.04979231581091881




Epoch 71, Batch 40, Loss: 0.045769479125738144




Epoch 71, Batch 50, Loss: 0.04436715319752693




Epoch 71, Batch 60, Loss: 0.042228639125823975




Epoch 71, Batch 70, Loss: 0.04028601944446564




Epoch 71, Batch 80, Loss: 0.03463079035282135




Epoch 71, Batch 90, Loss: 0.05827432870864868


Epoch:  72%|███████▏  | 72/100 [38:06<14:58, 32.09s/it]

Epoch 72, Batch 0, Loss: 0.05190199613571167




Epoch 72, Batch 10, Loss: 0.04207702726125717




Epoch 72, Batch 20, Loss: 0.027425574138760567




Epoch 72, Batch 30, Loss: 0.0536818653345108




Epoch 72, Batch 40, Loss: 0.04195065423846245




Epoch 72, Batch 50, Loss: 0.04960860311985016




Epoch 72, Batch 60, Loss: 0.04964669421315193




Epoch 72, Batch 70, Loss: 0.05596870556473732




Epoch 72, Batch 80, Loss: 0.04220049828290939




Epoch 72, Batch 90, Loss: 0.056744594126939774


Epoch:  73%|███████▎  | 73/100 [38:38<14:30, 32.26s/it]

Epoch 73, Batch 0, Loss: 0.02934359386563301




Epoch 73, Batch 10, Loss: 0.02825499325990677




Epoch 73, Batch 20, Loss: 0.04976918548345566




Epoch 73, Batch 30, Loss: 0.03371090069413185




Epoch 73, Batch 40, Loss: 0.03967956081032753




Epoch 73, Batch 50, Loss: 0.04428694769740105




Epoch 73, Batch 60, Loss: 0.05357813090085983




Epoch 73, Batch 70, Loss: 0.05858563259243965




Epoch 73, Batch 80, Loss: 0.06976937502622604




Epoch 73, Batch 90, Loss: 0.04296066612005234


Epoch:  74%|███████▍  | 74/100 [39:10<13:56, 32.16s/it]

Epoch 74, Batch 0, Loss: 0.026234041899442673




Epoch 74, Batch 10, Loss: 0.059968721121549606




Epoch 74, Batch 20, Loss: 0.03707817196846008




Epoch 74, Batch 30, Loss: 0.034071847796440125




Epoch 74, Batch 40, Loss: 0.025797205045819283




Epoch 74, Batch 50, Loss: 0.029710955917835236




Epoch 74, Batch 60, Loss: 0.03460289165377617




Epoch 74, Batch 70, Loss: 0.05163121968507767




Epoch 74, Batch 80, Loss: 0.02569308876991272




Epoch 74, Batch 90, Loss: 0.036835115402936935


Epoch:  75%|███████▌  | 75/100 [39:42<13:21, 32.06s/it]

Epoch 75, Batch 0, Loss: 0.06834671646356583




Epoch 75, Batch 10, Loss: 0.03111879713833332




Epoch 75, Batch 20, Loss: 0.02680608257651329




Epoch 75, Batch 30, Loss: 0.03707042708992958




Epoch 75, Batch 40, Loss: 0.03445301577448845




Epoch 75, Batch 50, Loss: 0.026676150038838387




Epoch 75, Batch 60, Loss: 0.027807122096419334




Epoch 75, Batch 70, Loss: 0.046652667224407196




Epoch 75, Batch 80, Loss: 0.025796085596084595




Epoch 75, Batch 90, Loss: 0.028193701058626175


Epoch:  76%|███████▌  | 76/100 [40:14<12:48, 32.01s/it]

Epoch 76, Batch 0, Loss: 0.03167036920785904




Epoch 76, Batch 10, Loss: 0.03497723862528801




Epoch 76, Batch 20, Loss: 0.02759263850748539




Epoch 76, Batch 30, Loss: 0.021143002435564995




Epoch 76, Batch 40, Loss: 0.02378050796687603




Epoch 76, Batch 50, Loss: 0.045461710542440414




Epoch 76, Batch 60, Loss: 0.02375439926981926




Epoch 76, Batch 70, Loss: 0.018694499507546425




Epoch 76, Batch 80, Loss: 0.016588961705565453




Epoch 76, Batch 90, Loss: 0.05227358266711235


Epoch:  77%|███████▋  | 77/100 [40:46<12:15, 31.97s/it]

Epoch 77, Batch 0, Loss: 0.030619969591498375




Epoch 77, Batch 10, Loss: 0.02563767321407795




Epoch 77, Batch 20, Loss: 0.024050328880548477




Epoch 77, Batch 30, Loss: 0.022509058937430382




Epoch 77, Batch 40, Loss: 0.019872939214110374




Epoch 77, Batch 50, Loss: 0.016834858804941177




Epoch 77, Batch 60, Loss: 0.022951750084757805




Epoch 77, Batch 70, Loss: 0.02727321721613407




Epoch 77, Batch 80, Loss: 0.0187221672385931




Epoch 77, Batch 90, Loss: 0.0263135377317667


Epoch:  78%|███████▊  | 78/100 [41:18<11:42, 31.91s/it]

Epoch 78, Batch 0, Loss: 0.021881889551877975




Epoch 78, Batch 10, Loss: 0.022798743098974228




Epoch 78, Batch 20, Loss: 0.014911274425685406




Epoch 78, Batch 30, Loss: 0.026761207729578018




Epoch 78, Batch 40, Loss: 0.02369818277657032




Epoch 78, Batch 50, Loss: 0.022769393399357796




Epoch 78, Batch 60, Loss: 0.016371358186006546




Epoch 78, Batch 70, Loss: 0.026031596586108208




Epoch 78, Batch 80, Loss: 0.02113904058933258




Epoch 78, Batch 90, Loss: 0.024079157039523125


Epoch:  79%|███████▉  | 79/100 [41:50<11:12, 32.04s/it]

Epoch 79, Batch 0, Loss: 0.03897907957434654




Epoch 79, Batch 10, Loss: 0.03562823310494423




Epoch 79, Batch 20, Loss: 0.019990015774965286




Epoch 79, Batch 30, Loss: 0.020788175985217094




Epoch 79, Batch 40, Loss: 0.02716199681162834




Epoch 79, Batch 50, Loss: 0.018106356263160706




Epoch 79, Batch 60, Loss: 0.024881822988390923




Epoch 79, Batch 70, Loss: 0.021642174571752548




Epoch 79, Batch 80, Loss: 0.012338396161794662




Epoch 79, Batch 90, Loss: 0.025310944765806198


Epoch:  80%|████████  | 80/100 [42:23<10:48, 32.43s/it]

Epoch 80, Batch 0, Loss: 0.01839287020266056




Epoch 80, Batch 10, Loss: 0.019102424383163452




Epoch 80, Batch 20, Loss: 0.023471763357520103




Epoch 80, Batch 30, Loss: 0.019310643896460533




Epoch 80, Batch 40, Loss: 0.010065915994346142




Epoch 80, Batch 50, Loss: 0.01423213817179203




Epoch 80, Batch 60, Loss: 0.017968980595469475




Epoch 80, Batch 70, Loss: 0.015859108418226242




Epoch 80, Batch 80, Loss: 0.013296937569975853




Epoch 80, Batch 90, Loss: 0.014103841036558151


Epoch:  81%|████████  | 81/100 [42:55<10:14, 32.35s/it]

Epoch 81, Batch 0, Loss: 0.01403124164789915




Epoch 81, Batch 10, Loss: 0.017752064391970634




Epoch 81, Batch 20, Loss: 0.024078965187072754




Epoch 81, Batch 30, Loss: 0.013234622776508331




Epoch 81, Batch 40, Loss: 0.028748592361807823




Epoch 81, Batch 50, Loss: 0.026421701535582542




Epoch 81, Batch 60, Loss: 0.020773502066731453




Epoch 81, Batch 70, Loss: 0.029291348531842232




Epoch 81, Batch 80, Loss: 0.01115849893540144




Epoch 81, Batch 90, Loss: 0.020769212394952774


Epoch:  82%|████████▏ | 82/100 [43:27<09:39, 32.21s/it]

Epoch 82, Batch 0, Loss: 0.013409932143986225




Epoch 82, Batch 10, Loss: 0.020462770015001297




Epoch 82, Batch 20, Loss: 0.019926857203245163




Epoch 82, Batch 30, Loss: 0.0166784655302763




Epoch 82, Batch 40, Loss: 0.013751549646258354




Epoch 82, Batch 50, Loss: 0.0074525498785078526




Epoch 82, Batch 60, Loss: 0.014707074500620365




Epoch 82, Batch 70, Loss: 0.008785207755863667




Epoch 82, Batch 80, Loss: 0.011897829361259937




Epoch 82, Batch 90, Loss: 0.015929941087961197


Epoch:  83%|████████▎ | 83/100 [43:59<09:04, 32.02s/it]

Epoch 83, Batch 0, Loss: 0.013242963701486588




Epoch 83, Batch 10, Loss: 0.015777811408042908




Epoch 83, Batch 20, Loss: 0.004890414886176586




Epoch 83, Batch 30, Loss: 0.007019421085715294




Epoch 83, Batch 40, Loss: 0.015177060849964619




Epoch 83, Batch 50, Loss: 0.012601527385413647




Epoch 83, Batch 60, Loss: 0.013519050553441048




Epoch 83, Batch 70, Loss: 0.012937117367982864




Epoch 83, Batch 80, Loss: 0.011874658986926079




Epoch 83, Batch 90, Loss: 0.0074814786203205585


Epoch:  84%|████████▍ | 84/100 [44:31<08:32, 32.01s/it]

Epoch 84, Batch 0, Loss: 0.01038152165710926




Epoch 84, Batch 10, Loss: 0.006278098560869694




Epoch 84, Batch 20, Loss: 0.01095800381153822




Epoch 84, Batch 30, Loss: 0.008718216791749




Epoch 84, Batch 40, Loss: 0.01191967073827982




Epoch 84, Batch 50, Loss: 0.018537795171141624




Epoch 84, Batch 60, Loss: 0.008298685774207115




Epoch 84, Batch 70, Loss: 0.015069138258695602




Epoch 84, Batch 80, Loss: 0.011231176555156708




Epoch 84, Batch 90, Loss: 0.014880308881402016


Epoch:  85%|████████▌ | 85/100 [45:03<07:58, 31.92s/it]

Epoch 85, Batch 0, Loss: 0.014013644307851791




Epoch 85, Batch 10, Loss: 0.010010828264057636




Epoch 85, Batch 20, Loss: 0.010701127350330353




Epoch 85, Batch 30, Loss: 0.009927413426339626




Epoch 85, Batch 40, Loss: 0.010493568144738674




Epoch 85, Batch 50, Loss: 0.007668422069400549




Epoch 85, Batch 60, Loss: 0.00854824110865593




Epoch 85, Batch 70, Loss: 0.012232830747961998




Epoch 85, Batch 80, Loss: 0.007418727036565542




Epoch 85, Batch 90, Loss: 0.016256460919976234


Epoch:  86%|████████▌ | 86/100 [45:34<07:25, 31.80s/it]

Epoch 86, Batch 0, Loss: 0.01312908623367548




Epoch 86, Batch 10, Loss: 0.011467295698821545




Epoch 86, Batch 20, Loss: 0.004175766371190548




Epoch 86, Batch 30, Loss: 0.0076423706486821175




Epoch 86, Batch 40, Loss: 0.004981360863894224




Epoch 86, Batch 50, Loss: 0.0036157865542918444




Epoch 86, Batch 60, Loss: 0.00849730335175991




Epoch 86, Batch 70, Loss: 0.00576033815741539




Epoch 86, Batch 80, Loss: 0.004859162028878927




Epoch 86, Batch 90, Loss: 0.012253482826054096


Epoch:  87%|████████▋ | 87/100 [46:06<06:52, 31.75s/it]

Epoch 87, Batch 0, Loss: 0.007708454970270395




Epoch 87, Batch 10, Loss: 0.006939450278878212




Epoch 87, Batch 20, Loss: 0.004244954325258732




Epoch 87, Batch 30, Loss: 0.010858193039894104




Epoch 87, Batch 40, Loss: 0.009062803350389004




Epoch 87, Batch 50, Loss: 0.008193704299628735




Epoch 87, Batch 60, Loss: 0.0029882616363465786




Epoch 87, Batch 70, Loss: 0.006046729162335396




Epoch 87, Batch 80, Loss: 0.0044098347425460815




Epoch 87, Batch 90, Loss: 0.005961727816611528


Epoch:  88%|████████▊ | 88/100 [46:37<06:20, 31.70s/it]

Epoch 88, Batch 0, Loss: 0.005968322977423668




Epoch 88, Batch 10, Loss: 0.005481491796672344




Epoch 88, Batch 20, Loss: 0.004266311880201101




Epoch 88, Batch 30, Loss: 0.005784132052212954




Epoch 88, Batch 40, Loss: 0.0030560491140931845




Epoch 88, Batch 50, Loss: 0.005342903081327677




Epoch 88, Batch 60, Loss: 0.008519223891198635




Epoch 88, Batch 70, Loss: 0.01592395082116127




Epoch 88, Batch 80, Loss: 0.013158011250197887




Epoch 88, Batch 90, Loss: 0.007185420952737331


Epoch:  89%|████████▉ | 89/100 [47:09<05:48, 31.67s/it]

Epoch 89, Batch 0, Loss: 0.013958401046693325




Epoch 89, Batch 10, Loss: 0.005265415646135807




Epoch 89, Batch 20, Loss: 0.0036481700371950865




Epoch 89, Batch 30, Loss: 0.00801887083798647




Epoch 89, Batch 40, Loss: 0.004744189791381359




Epoch 89, Batch 50, Loss: 0.0038459738716483116




Epoch 89, Batch 60, Loss: 0.006070858333259821




Epoch 89, Batch 70, Loss: 0.0075088259764015675




Epoch 89, Batch 80, Loss: 0.0024154973216354847




Epoch 89, Batch 90, Loss: 0.006295537576079369


Epoch:  90%|█████████ | 90/100 [47:41<05:16, 31.68s/it]

Epoch 90, Batch 0, Loss: 0.00923341978341341




Epoch 90, Batch 10, Loss: 0.005537156015634537




Epoch 90, Batch 20, Loss: 0.0024931791704148054




Epoch 90, Batch 30, Loss: 0.006348148453980684




Epoch 90, Batch 40, Loss: 0.0030181787442415953




Epoch 90, Batch 50, Loss: 0.0025040304753929377




Epoch 90, Batch 60, Loss: 0.008943844586610794




Epoch 90, Batch 70, Loss: 0.004098467528820038




Epoch 90, Batch 80, Loss: 0.00450856052339077




Epoch 90, Batch 90, Loss: 0.006081086117774248


Epoch:  91%|█████████ | 91/100 [48:13<04:45, 31.76s/it]

Epoch 91, Batch 0, Loss: 0.0025687734596431255




Epoch 91, Batch 10, Loss: 0.002546780975535512




Epoch 91, Batch 20, Loss: 0.007245756685733795




Epoch 91, Batch 30, Loss: 0.003962195478379726




Epoch 91, Batch 40, Loss: 0.002674092771485448




Epoch 91, Batch 50, Loss: 0.0018059704452753067




Epoch 91, Batch 60, Loss: 0.0024596655275672674




Epoch 91, Batch 70, Loss: 0.007377460598945618




Epoch 91, Batch 80, Loss: 0.001282864366658032




Epoch 91, Batch 90, Loss: 0.002171096159145236


Epoch:  92%|█████████▏| 92/100 [48:44<04:14, 31.78s/it]

Epoch 92, Batch 0, Loss: 0.0036010348703712225




Epoch 92, Batch 10, Loss: 0.0019388683140277863




Epoch 92, Batch 20, Loss: 0.0032671408262103796




Epoch 92, Batch 30, Loss: 0.008852814324200153




Epoch 92, Batch 40, Loss: 0.004608485382050276




Epoch 92, Batch 50, Loss: 0.0018593883141875267




Epoch 92, Batch 60, Loss: 0.002334450837224722




Epoch 92, Batch 70, Loss: 0.0030833191704005003




Epoch 92, Batch 80, Loss: 0.001162656699307263




Epoch 92, Batch 90, Loss: 0.0025249216705560684


Epoch:  93%|█████████▎| 93/100 [49:16<03:42, 31.85s/it]

Epoch 93, Batch 0, Loss: 0.004318303894251585




Epoch 93, Batch 10, Loss: 0.006345716305077076




Epoch 93, Batch 20, Loss: 0.0061251139268279076




Epoch 93, Batch 30, Loss: 0.006072103511542082




Epoch 93, Batch 40, Loss: 0.0021406803280115128




Epoch 93, Batch 50, Loss: 0.0009851378854364157




Epoch 93, Batch 60, Loss: 0.0021985850762575865




Epoch 93, Batch 70, Loss: 0.002591464202851057




Epoch 93, Batch 80, Loss: 0.0017335667507722974




Epoch 93, Batch 90, Loss: 0.002964376239106059


Epoch:  94%|█████████▍| 94/100 [49:48<03:10, 31.79s/it]

Epoch 94, Batch 0, Loss: 0.002305240137502551




Epoch 94, Batch 10, Loss: 0.0029834769666194916




Epoch 94, Batch 20, Loss: 0.00480224471539259




Epoch 94, Batch 30, Loss: 0.0032894578762352467




Epoch 94, Batch 40, Loss: 0.0026512080803513527




Epoch 94, Batch 50, Loss: 0.0028896541334688663




Epoch 94, Batch 60, Loss: 0.001584616955369711




Epoch 94, Batch 70, Loss: 0.0017750143306329846




Epoch 94, Batch 80, Loss: 0.0011752875288948417




Epoch 94, Batch 90, Loss: 0.0028364236932247877


Epoch:  95%|█████████▌| 95/100 [50:20<02:38, 31.76s/it]

Epoch 95, Batch 0, Loss: 0.0005809020367451012




Epoch 95, Batch 10, Loss: 0.0015043439343571663




Epoch 95, Batch 20, Loss: 0.0025832275860011578




Epoch 95, Batch 30, Loss: 0.0026819207705557346




Epoch 95, Batch 40, Loss: 0.0036146172787994146




Epoch 95, Batch 50, Loss: 0.0026331418193876743




Epoch 95, Batch 60, Loss: 0.0022358945570886135




Epoch 95, Batch 70, Loss: 0.0017641664016991854




Epoch 95, Batch 80, Loss: 0.008777720853686333




Epoch 95, Batch 90, Loss: 0.0030845007859170437


Epoch:  96%|█████████▌| 96/100 [50:51<02:06, 31.72s/it]

Epoch 96, Batch 0, Loss: 0.001074683852493763




Epoch 96, Batch 10, Loss: 0.001553395763039589




Epoch 96, Batch 20, Loss: 0.0014797095209360123




Epoch 96, Batch 30, Loss: 0.0017362299840897322




Epoch 96, Batch 40, Loss: 0.0007967468118295074




Epoch 96, Batch 50, Loss: 0.0008390507427975535




Epoch 96, Batch 60, Loss: 0.0005269543034955859




Epoch 96, Batch 70, Loss: 0.002777493791654706




Epoch 96, Batch 80, Loss: 0.003947874531149864




Epoch 96, Batch 90, Loss: 0.002795059932395816


Epoch:  97%|█████████▋| 97/100 [51:23<01:35, 31.69s/it]

Epoch 97, Batch 0, Loss: 0.004439090844243765




Epoch 97, Batch 10, Loss: 0.0012773595517501235




Epoch 97, Batch 20, Loss: 0.0013557150959968567




Epoch 97, Batch 30, Loss: 0.0014122948050498962




Epoch 97, Batch 40, Loss: 0.004370799753814936




Epoch 97, Batch 50, Loss: 0.0028495590668171644




Epoch 97, Batch 60, Loss: 0.0011262035695835948




Epoch 97, Batch 70, Loss: 0.00457643810659647




Epoch 97, Batch 80, Loss: 0.002392886206507683




Epoch 97, Batch 90, Loss: 0.003528483444824815


Epoch:  98%|█████████▊| 98/100 [51:55<01:03, 31.78s/it]

Epoch 98, Batch 0, Loss: 0.0013632297050207853




Epoch 98, Batch 10, Loss: 0.0018629482947289944




Epoch 98, Batch 20, Loss: 0.0009041162556968629




Epoch 98, Batch 30, Loss: 0.0007462063804268837




Epoch 98, Batch 40, Loss: 0.0006233271560631692




Epoch 98, Batch 50, Loss: 0.0020601071882992983




Epoch 98, Batch 60, Loss: 0.0007102392264641821




Epoch 98, Batch 70, Loss: 0.0005921609117649496




Epoch 98, Batch 80, Loss: 0.0017807057593017817




Epoch 98, Batch 90, Loss: 0.0005712839192710817


Epoch:  99%|█████████▉| 99/100 [52:28<00:32, 32.02s/it]

Epoch 99, Batch 0, Loss: 0.0008252875413745642




Epoch 99, Batch 10, Loss: 0.00290026911534369




Epoch 99, Batch 20, Loss: 0.0025920828338712454




Epoch 99, Batch 30, Loss: 0.003507895627990365




Epoch 99, Batch 40, Loss: 0.0010184973943978548




Epoch 99, Batch 50, Loss: 0.0018006853060796857




Epoch 99, Batch 60, Loss: 0.0016487385146319866




Epoch 99, Batch 70, Loss: 0.0024415033403784037




Epoch 99, Batch 80, Loss: 0.002229855628684163




Epoch 99, Batch 90, Loss: 0.0007555995834991336


Epoch: 100%|██████████| 100/100 [53:00<00:00, 31.81s/it]


In [9]:
#@markdown ### **Loading Pretrained Checkpoint**
#@markdown Set `load_pretrained = True` to load pretrained weights.


ckpt_path = "ckpts/model_checkpoint_6_0_6.084841061237967e-06.pth"

state_dict = torch.load(ckpt_path, map_location='cuda')
noise_pred_net.load_state_dict(state_dict['model_state_dict'])
imm.model = noise_pred_net 

In [23]:
#@markdown ### **Inference**

# limit enviornment interaction to 200 steps before termination
max_steps = 200
env = PushTEnv()
# use a seed >200 to avoid initial states seen in the training dataset
env.seed(100000)

# get first observation
obs, info = env.reset()

# keep a queue of last 2 steps of observations
obs_deque = collections.deque(
    [obs] * obs_horizon, maxlen=obs_horizon)
# save visualization and rewards
imgs = [env.render(mode='rgb_array')]
rewards = list()
done = False
step_idx = 0

with tqdm(total=max_steps, desc="Eval PushTStateEnv") as pbar:
    while not done:
        B = 1
        # stack the last obs_horizon (2) number of observations
        obs_seq = np.stack(obs_deque)
        # normalize observation
        nobs = normalize_data(obs_seq, stats=stats['obs'])
        # device transfer
        nobs = torch.from_numpy(nobs).to(device, dtype=torch.float32)

        # infer action
        with torch.no_grad():
            # reshape observation to (B,obs_horizon*obs_dim)
            obs_cond = nobs.unsqueeze(0).flatten(start_dim=1)

            # initialize action from Guassian noise
            noisy_action = torch.randn(
                (B, pred_horizon, action_dim), device=device)
            naction = noisy_action

            r = imm.sample(shape=(1, pred_horizon, action_dim), steps=10, global_cond=obs_cond)
            # for k in noise_scheduler.timesteps:
            #     # predict noise
            #     noise_pred = ema_noise_pred_net(
            #         sample=naction,
            #         timestep=k,
            #         global_cond=obs_cond
            #     )

            #     # inverse diffusion step (remove noise)
            #     naction = noise_scheduler.step(
            #         model_output=noise_pred,
            #         timestep=k,
            #         sample=naction
            #     ).prev_sample

        # unnormalize action
        naction = naction.detach().to('cpu').numpy()
        # (B, pred_horizon, action_dim)
        naction = naction[0]
        action_pred = unnormalize_data(naction, stats=stats['action'])

        # only take action_horizon number of actions
        start = obs_horizon - 1
        end = start + action_horizon
        action = action_pred[start:end,:]
        # (action_horizon, action_dim)

        # execute action_horizon number of steps
        # without replanning
        for i in range(len(action)):
            # stepping env
            obs, reward, done, _, info = env.step(action[i])
            # save observations
            obs_deque.append(obs)
            # and reward/vis
            rewards.append(reward)
            imgs.append(env.render(mode='rgb_array'))

            # update progress bar
            step_idx += 1
            pbar.update(1)
            pbar.set_postfix(reward=reward)
            if step_idx > max_steps:
                done = True
            if done:
                break

# print out the maximum target coverage
print('Score: ', max(rewards))

# visualize
from IPython.display import Video
vwrite('vis.mp4', imgs)
Video('vis.mp4', embed=True, width=256, height=256)

Eval PushTStateEnv: 201it [00:01, 191.77it/s, reward=0]                         


Score:  0.0
