In [1]:
#@markdown ### **Installing pip packages**
#@markdown - Diffusion Model: [PyTorch](https://pytorch.org) & [HuggingFace diffusers](https://huggingface.co/docs/diffusers/index)
#@markdown - Dataset Loading: [Zarr](https://zarr.readthedocs.io/en/stable/) & numcodecs
#@markdown - Push-T Env: gym, pygame, pymunk & shapely
!python --version
!pip3 install torch torchvision diffusers \
scikit-image scikit-video zarr numcodecs \
pygame pymunk gym shapely opencv-python

Python 3.10.12
Collecting torch
  Downloading torch-2.6.0-cp310-cp310-manylinux1_x86_64.whl (766.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m766.7/766.7 MB[0m [31m10.7 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting torchvision
  Downloading torchvision-0.21.0-cp310-cp310-manylinux1_x86_64.whl (7.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.2/7.2 MB[0m [31m65.1 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting diffusers
  Downloading diffusers-0.33.1-py3-none-any.whl (3.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.6/3.6 MB[0m [31m90.9 MB/s[0m eta [36m0:00:00[0m:00:01[0m
[?25hCollecting scikit-image
  Downloading scikit_image-0.25.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (14.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.8/14.8 MB[0m [31m84.5 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting scikit-video
 

In [1]:
#@markdown ### **Imports**
# diffusion policy import
from typing import Tuple, Sequence, Dict, Union, Optional
import numpy as np
import math
import torch
import torch.nn as nn
import collections
import zarr
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.training_utils import EMAModel
from diffusers.optimization import get_scheduler
from tqdm.auto import tqdm

# env import
import gym
from gym import spaces
import pygame
import pymunk
import pymunk.pygame_util
from pymunk.space_debug_draw_options import SpaceDebugColor
from pymunk.vec2d import Vec2d
import shapely.geometry as sg
import cv2
import skimage.transform as st
from skvideo.io import vwrite
from IPython.display import Video
import gdown
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#@markdown ### **Environment**
#@markdown Defines a PyMunk-based Push-T environment `PushTEnv`.
#@markdown
#@markdown **Goal**: push the gray T-block into the green area.
#@markdown
#@markdown Adapted from [Implicit Behavior Cloning](https://implicitbc.github.io/)


positive_y_is_up: bool = False
"""Make increasing values of y point upwards.

When True::

    y
    ^
    |      . (3, 3)
    |
    |   . (2, 2)
    |
    +------ > x

When False::

    +------ > x
    |
    |   . (2, 2)
    |
    |      . (3, 3)
    v
    y

"""

def to_pygame(p: Tuple[float, float], surface: pygame.Surface) -> Tuple[int, int]:
    """Convenience method to convert pymunk coordinates to pygame surface
    local coordinates.

    Note that in case positive_y_is_up is False, this function wont actually do
    anything except converting the point to integers.
    """
    if positive_y_is_up:
        return round(p[0]), surface.get_height() - round(p[1])
    else:
        return round(p[0]), round(p[1])


def light_color(color: SpaceDebugColor):
    color = np.minimum(1.2 * np.float32([color.r, color.g, color.b, color.a]), np.float32([255]))
    color = SpaceDebugColor(r=color[0], g=color[1], b=color[2], a=color[3])
    return color

class DrawOptions(pymunk.SpaceDebugDrawOptions):
    def __init__(self, surface: pygame.Surface) -> None:
        """Draw a pymunk.Space on a pygame.Surface object.

        Typical usage::

        >>> import pymunk
        >>> surface = pygame.Surface((10,10))
        >>> space = pymunk.Space()
        >>> options = pymunk.pygame_util.DrawOptions(surface)
        >>> space.debug_draw(options)

        You can control the color of a shape by setting shape.color to the color
        you want it drawn in::

        >>> c = pymunk.Circle(None, 10)
        >>> c.color = pygame.Color("pink")

        See pygame_util.demo.py for a full example

        Since pygame uses a coordiante system where y points down (in contrast
        to many other cases), you either have to make the physics simulation
        with Pymunk also behave in that way, or flip everything when you draw.

        The easiest is probably to just make the simulation behave the same
        way as Pygame does. In that way all coordinates used are in the same
        orientation and easy to reason about::

        >>> space = pymunk.Space()
        >>> space.gravity = (0, -1000)
        >>> body = pymunk.Body()
        >>> body.position = (0, 0) # will be positioned in the top left corner
        >>> space.debug_draw(options)

        To flip the drawing its possible to set the module property
        :py:data:`positive_y_is_up` to True. Then the pygame drawing will flip
        the simulation upside down before drawing::

        >>> positive_y_is_up = True
        >>> body = pymunk.Body()
        >>> body.position = (0, 0)
        >>> # Body will be position in bottom left corner

        :Parameters:
                surface : pygame.Surface
                    Surface that the objects will be drawn on
        """
        self.surface = surface
        super(DrawOptions, self).__init__()

    def draw_circle(
        self,
        pos: Vec2d,
        angle: float,
        radius: float,
        outline_color: SpaceDebugColor,
        fill_color: SpaceDebugColor,
    ) -> None:
        p = to_pygame(pos, self.surface)

        pygame.draw.circle(self.surface, fill_color.as_int(), p, round(radius), 0)
        pygame.draw.circle(self.surface, light_color(fill_color).as_int(), p, round(radius-4), 0)

        circle_edge = pos + Vec2d(radius, 0).rotated(angle)
        p2 = to_pygame(circle_edge, self.surface)
        line_r = 2 if radius > 20 else 1
        # pygame.draw.lines(self.surface, outline_color.as_int(), False, [p, p2], line_r)

    def draw_segment(self, a: Vec2d, b: Vec2d, color: SpaceDebugColor) -> None:
        p1 = to_pygame(a, self.surface)
        p2 = to_pygame(b, self.surface)

        pygame.draw.aalines(self.surface, color.as_int(), False, [p1, p2])

    def draw_fat_segment(
        self,
        a: Tuple[float, float],
        b: Tuple[float, float],
        radius: float,
        outline_color: SpaceDebugColor,
        fill_color: SpaceDebugColor,
    ) -> None:
        p1 = to_pygame(a, self.surface)
        p2 = to_pygame(b, self.surface)

        r = round(max(1, radius * 2))
        pygame.draw.lines(self.surface, fill_color.as_int(), False, [p1, p2], r)
        if r > 2:
            orthog = [abs(p2[1] - p1[1]), abs(p2[0] - p1[0])]
            if orthog[0] == 0 and orthog[1] == 0:
                return
            scale = radius / (orthog[0] * orthog[0] + orthog[1] * orthog[1]) ** 0.5
            orthog[0] = round(orthog[0] * scale)
            orthog[1] = round(orthog[1] * scale)
            points = [
                (p1[0] - orthog[0], p1[1] - orthog[1]),
                (p1[0] + orthog[0], p1[1] + orthog[1]),
                (p2[0] + orthog[0], p2[1] + orthog[1]),
                (p2[0] - orthog[0], p2[1] - orthog[1]),
            ]
            pygame.draw.polygon(self.surface, fill_color.as_int(), points)
            pygame.draw.circle(
                self.surface,
                fill_color.as_int(),
                (round(p1[0]), round(p1[1])),
                round(radius),
            )
            pygame.draw.circle(
                self.surface,
                fill_color.as_int(),
                (round(p2[0]), round(p2[1])),
                round(radius),
            )

    def draw_polygon(
        self,
        verts: Sequence[Tuple[float, float]],
        radius: float,
        outline_color: SpaceDebugColor,
        fill_color: SpaceDebugColor,
    ) -> None:
        ps = [to_pygame(v, self.surface) for v in verts]
        ps += [ps[0]]

        radius = 2
        pygame.draw.polygon(self.surface, light_color(fill_color).as_int(), ps)

        if radius > 0:
            for i in range(len(verts)):
                a = verts[i]
                b = verts[(i + 1) % len(verts)]
                self.draw_fat_segment(a, b, radius, fill_color, fill_color)

    def draw_dot(
        self, size: float, pos: Tuple[float, float], color: SpaceDebugColor
    ) -> None:
        p = to_pygame(pos, self.surface)
        pygame.draw.circle(self.surface, color.as_int(), p, round(size), 0)


def pymunk_to_shapely(body, shapes):
    geoms = list()
    for shape in shapes:
        if isinstance(shape, pymunk.shapes.Poly):
            verts = [body.local_to_world(v) for v in shape.get_vertices()]
            verts += [verts[0]]
            geoms.append(sg.Polygon(verts))
        else:
            raise RuntimeError(f'Unsupported shape type {type(shape)}')
    geom = sg.MultiPolygon(geoms)
    return geom

# env
class PushTEnv(gym.Env):
    metadata = {"render.modes": ["human", "rgb_array"], "video.frames_per_second": 10}
    reward_range = (0., 1.)

    def __init__(self,
            legacy=False,
            block_cog=None, damping=None,
            render_action=True,
            render_size=96,
            reset_to_state=None
        ):
        self._seed = None
        self.seed()
        self.window_size = ws = 512  # The size of the PyGame window
        self.render_size = render_size
        self.sim_hz = 100
        # Local controller params.
        self.k_p, self.k_v = 100, 20    # PD control.z
        self.control_hz = self.metadata['video.frames_per_second']
        # legcay set_state for data compatiblity
        self.legacy = legacy

        # agent_pos, block_pos, block_angle
        self.observation_space = spaces.Box(
            low=np.array([0,0,0,0,0], dtype=np.float64),
            high=np.array([ws,ws,ws,ws,np.pi*2], dtype=np.float64),
            shape=(5,),
            dtype=np.float64
        )

        # positional goal for agent
        self.action_space = spaces.Box(
            low=np.array([0,0], dtype=np.float64),
            high=np.array([ws,ws], dtype=np.float64),
            shape=(2,),
            dtype=np.float64
        )

        self.block_cog = block_cog
        self.damping = damping
        self.render_action = render_action

        """
        If human-rendering is used, `self.window` will be a reference
        to the window that we draw to. `self.clock` will be a clock that is used
        to ensure that the environment is rendered at the correct framerate in
        human-mode. They will remain `None` until human-mode is used for the
        first time.
        """
        self.window = None
        self.clock = None
        self.screen = None

        self.space = None
        self.teleop = None
        self.render_buffer = None
        self.latest_action = None
        self.reset_to_state = reset_to_state

    def reset(self):
        seed = self._seed
        self._setup()
        if self.block_cog is not None:
            self.block.center_of_gravity = self.block_cog
        if self.damping is not None:
            self.space.damping = self.damping

        # use legacy RandomState for compatiblity
        state = self.reset_to_state
        if state is None:
            rs = np.random.RandomState(seed=seed)
            state = np.array([
                rs.randint(50, 450), rs.randint(50, 450),
                rs.randint(100, 400), rs.randint(100, 400),
                rs.randn() * 2 * np.pi - np.pi
                ])
        self._set_state(state)

        obs = self._get_obs()
        info = self._get_info()
        return obs, info

    def step(self, action):
        dt = 1.0 / self.sim_hz
        self.n_contact_points = 0
        n_steps = self.sim_hz // self.control_hz
        if action is not None:
            self.latest_action = action
            for i in range(n_steps):
                # Step PD control.
                # self.agent.velocity = self.k_p * (act - self.agent.position)    # P control works too.
                acceleration = self.k_p * (action - self.agent.position) + self.k_v * (Vec2d(0, 0) - self.agent.velocity)
                self.agent.velocity += acceleration * dt

                # Step physics.
                self.space.step(dt)

        # compute reward
        goal_body = self._get_goal_pose_body(self.goal_pose)
        goal_geom = pymunk_to_shapely(goal_body, self.block.shapes)
        block_geom = pymunk_to_shapely(self.block, self.block.shapes)

        intersection_area = goal_geom.intersection(block_geom).area
        goal_area = goal_geom.area
        coverage = intersection_area / goal_area
        reward = np.clip(coverage / self.success_threshold, 0, 1)
        done = coverage > self.success_threshold
        terminated = done
        truncated = done

        observation = self._get_obs()
        info = self._get_info()

        return observation, reward, terminated, truncated, info

    def render(self, mode):
        return self._render_frame(mode)

    def teleop_agent(self):
        TeleopAgent = collections.namedtuple('TeleopAgent', ['act'])
        def act(obs):
            act = None
            mouse_position = pymunk.pygame_util.from_pygame(Vec2d(*pygame.mouse.get_pos()), self.screen)
            if self.teleop or (mouse_position - self.agent.position).length < 30:
                self.teleop = True
                act = mouse_position
            return act
        return TeleopAgent(act)

    def _get_obs(self):
        obs = np.array(
            tuple(self.agent.position) \
            + tuple(self.block.position) \
            + (self.block.angle % (2 * np.pi),))
        return obs

    def _get_goal_pose_body(self, pose):
        mass = 1
        inertia = pymunk.moment_for_box(mass, (50, 100))
        body = pymunk.Body(mass, inertia)
        # preserving the legacy assignment order for compatibility
        # the order here dosn't matter somehow, maybe because CoM is aligned with body origin
        body.position = pose[:2].tolist()
        body.angle = pose[2]
        return body

    def _get_info(self):
        n_steps = self.sim_hz // self.control_hz
        n_contact_points_per_step = int(np.ceil(self.n_contact_points / n_steps))
        info = {
            'pos_agent': np.array(self.agent.position),
            'vel_agent': np.array(self.agent.velocity),
            'block_pose': np.array(list(self.block.position) + [self.block.angle]),
            'goal_pose': self.goal_pose,
            'n_contacts': n_contact_points_per_step}
        return info

    def _render_frame(self, mode):

        if self.window is None and mode == "human":
            pygame.init()
            pygame.display.init()
            self.window = pygame.display.set_mode((self.window_size, self.window_size))
        if self.clock is None and mode == "human":
            self.clock = pygame.time.Clock()

        canvas = pygame.Surface((self.window_size, self.window_size))
        canvas.fill((255, 255, 255))
        self.screen = canvas

        draw_options = DrawOptions(canvas)

        # Draw goal pose.
        goal_body = self._get_goal_pose_body(self.goal_pose)
        for shape in self.block.shapes:
            goal_points = [pymunk.pygame_util.to_pygame(goal_body.local_to_world(v), draw_options.surface) for v in shape.get_vertices()]
            goal_points += [goal_points[0]]
            pygame.draw.polygon(canvas, self.goal_color, goal_points)

        # Draw agent and block.
        self.space.debug_draw(draw_options)

        if mode == "human":
            # The following line copies our drawings from `canvas` to the visible window
            self.window.blit(canvas, canvas.get_rect())
            pygame.event.pump()
            pygame.display.update()

            # the clock is aleady ticked during in step for "human"


        img = np.transpose(
                np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2)
            )
        img = cv2.resize(img, (self.render_size, self.render_size))
        if self.render_action:
            if self.render_action and (self.latest_action is not None):
                action = np.array(self.latest_action)
                coord = (action / 512 * 96).astype(np.int32)
                marker_size = int(8/96*self.render_size)
                thickness = int(1/96*self.render_size)
                cv2.drawMarker(img, coord,
                    color=(255,0,0), markerType=cv2.MARKER_CROSS,
                    markerSize=marker_size, thickness=thickness)
        return img


    def close(self):
        if self.window is not None:
            pygame.display.quit()
            pygame.quit()

    def seed(self, seed=None):
        if seed is None:
            seed = np.random.randint(0,25536)
        self._seed = seed
        self.np_random = np.random.default_rng(seed)

    def _handle_collision(self, arbiter, space, data):
        self.n_contact_points += len(arbiter.contact_point_set.points)

    def _set_state(self, state):
        if isinstance(state, np.ndarray):
            state = state.tolist()
        pos_agent = state[:2]
        pos_block = state[2:4]
        rot_block = state[4]
        self.agent.position = pos_agent
        # setting angle rotates with respect to center of mass
        # therefore will modify the geometric position
        # if not the same as CoM
        # therefore should be modified first.
        if self.legacy:
            # for compatiblity with legacy data
            self.block.position = pos_block
            self.block.angle = rot_block
        else:
            self.block.angle = rot_block
            self.block.position = pos_block

        # Run physics to take effect
        self.space.step(1.0 / self.sim_hz)

    def _set_state_local(self, state_local):
        agent_pos_local = state_local[:2]
        block_pose_local = state_local[2:]
        tf_img_obj = st.AffineTransform(
            translation=self.goal_pose[:2],
            rotation=self.goal_pose[2])
        tf_obj_new = st.AffineTransform(
            translation=block_pose_local[:2],
            rotation=block_pose_local[2]
        )
        tf_img_new = st.AffineTransform(
            matrix=tf_img_obj.params @ tf_obj_new.params
        )
        agent_pos_new = tf_img_new(agent_pos_local)
        new_state = np.array(
            list(agent_pos_new[0]) + list(tf_img_new.translation) \
                + [tf_img_new.rotation])
        self._set_state(new_state)
        return new_state

    def _setup(self):
        self.space = pymunk.Space()
        self.space.gravity = 0, 0
        self.space.damping = 0
        self.teleop = False
        self.render_buffer = list()

        # Add walls.
        walls = [
            self._add_segment((5, 506), (5, 5), 2),
            self._add_segment((5, 5), (506, 5), 2),
            self._add_segment((506, 5), (506, 506), 2),
            self._add_segment((5, 506), (506, 506), 2)
        ]
        self.space.add(*walls)

        # Add agent, block, and goal zone.
        self.agent = self.add_circle((256, 400), 15)
        self.block = self.add_tee((256, 300), 0)
        self.goal_color = pygame.Color('LightGreen')
        self.goal_pose = np.array([256,256,np.pi/4])  # x, y, theta (in radians)

        # Add collision handeling
        self.collision_handeler = self.space.add_collision_handler(0, 0)
        self.collision_handeler.post_solve = self._handle_collision
        self.n_contact_points = 0

        self.max_score = 50 * 100
        self.success_threshold = 0.95    # 95% coverage.

    def _add_segment(self, a, b, radius):
        shape = pymunk.Segment(self.space.static_body, a, b, radius)
        shape.color = pygame.Color('LightGray')    # https://htmlcolorcodes.com/color-names
        return shape

    def add_circle(self, position, radius):
        body = pymunk.Body(body_type=pymunk.Body.KINEMATIC)
        body.position = position
        body.friction = 1
        shape = pymunk.Circle(body, radius)
        shape.color = pygame.Color('RoyalBlue')
        self.space.add(body, shape)
        return body

    def add_box(self, position, height, width):
        mass = 1
        inertia = pymunk.moment_for_box(mass, (height, width))
        body = pymunk.Body(mass, inertia)
        body.position = position
        shape = pymunk.Poly.create_box(body, (height, width))
        shape.color = pygame.Color('LightSlateGray')
        self.space.add(body, shape)
        return body

    def add_tee(self, position, angle, scale=30, color='LightSlateGray', mask=pymunk.ShapeFilter.ALL_MASKS()):
        mass = 1
        length = 4
        vertices1 = [(-length*scale/2, scale),
                                 ( length*scale/2, scale),
                                 ( length*scale/2, 0),
                                 (-length*scale/2, 0)]
        inertia1 = pymunk.moment_for_poly(mass, vertices=vertices1)
        vertices2 = [(-scale/2, scale),
                                 (-scale/2, length*scale),
                                 ( scale/2, length*scale),
                                 ( scale/2, scale)]
        inertia2 = pymunk.moment_for_poly(mass, vertices=vertices1)
        body = pymunk.Body(mass, inertia1 + inertia2)
        shape1 = pymunk.Poly(body, vertices1)
        shape2 = pymunk.Poly(body, vertices2)
        shape1.color = pygame.Color(color)
        shape2.color = pygame.Color(color)
        shape1.filter = pymunk.ShapeFilter(mask=mask)
        shape2.filter = pymunk.ShapeFilter(mask=mask)
        body.center_of_gravity = (shape1.center_of_gravity + shape2.center_of_gravity) / 2
        body.position = position
        body.angle = angle
        body.friction = 1
        self.space.add(body, shape1, shape2)
        return body


In [3]:
# from huggingface_hub.utils import IGNORE_GIT_FOLDER_PATTERNS
#@markdown ### **Env Demo**
#@markdown Standard Gym Env (0.21.0 API)

# 0. create env object
env = PushTEnv()

# 1. seed env for initial state.
# Seed 0-200 are used for the demonstration dataset.
env.seed(1000)

# 2. must reset before use
obs, IGNORE_GIT_FOLDER_PATTERNS = env.reset()

# 3. 2D positional action space [0,512]
action = env.action_space.sample()

# 4. Standard gym step method
obs, reward, terminated, truncated, info = env.step(action)

# prints and explains each dimension of the observation and action vectors
with np.printoptions(precision=4, suppress=True, threshold=5):
    print("Obs: ", repr(obs))
    print("Obs:        [agent_x,  agent_y,  block_x,  block_y,    block_angle]")
    print("Action: ", repr(action))
    print("Action:   [target_agent_x, target_agent_y]")

Obs:  array([142.7087, 231.8552, 292.    , 351.    ,   2.9196])
Obs:        [agent_x,  agent_y,  block_x,  block_y,    block_angle]
Action:  array([156.3075, 495.9264])
Action:   [target_agent_x, target_agent_y]


In [4]:
#@markdown ### **Dataset**
#@markdown
#@markdown Defines `PushTStateDataset` and helper functions
#@markdown
#@markdown The dataset class
#@markdown - Load data (obs, action) from a zarr storage
#@markdown - Normalizes each dimension of obs and action to [-1,1]
#@markdown - Returns
#@markdown  - All possible segments with length `pred_horizon`
#@markdown  - Pads the beginning and the end of each episode with repetition
#@markdown  - key `obs`: shape (obs_horizon, obs_dim)
#@markdown  - key `action`: shape (pred_horizon, action_dim)

def create_sample_indices(
        episode_ends:np.ndarray, sequence_length:int,
        pad_before: int=0, pad_after: int=0):
    indices = list()
    for i in range(len(episode_ends)):
        start_idx = 0
        if i > 0:
            start_idx = episode_ends[i-1]
        end_idx = episode_ends[i]
        episode_length = end_idx - start_idx

        min_start = -pad_before
        max_start = episode_length - sequence_length + pad_after

        # range stops one idx before end
        for idx in range(min_start, max_start+1):
            buffer_start_idx = max(idx, 0) + start_idx
            buffer_end_idx = min(idx+sequence_length, episode_length) + start_idx
            start_offset = buffer_start_idx - (idx+start_idx)
            end_offset = (idx+sequence_length+start_idx) - buffer_end_idx
            sample_start_idx = 0 + start_offset
            sample_end_idx = sequence_length - end_offset
            indices.append([
                buffer_start_idx, buffer_end_idx,
                sample_start_idx, sample_end_idx])
    indices = np.array(indices)
    return indices


def sample_sequence(train_data, sequence_length,
                    buffer_start_idx, buffer_end_idx,
                    sample_start_idx, sample_end_idx):
    result = dict()
    for key, input_arr in train_data.items():
        sample = input_arr[buffer_start_idx:buffer_end_idx]
        data = sample
        if (sample_start_idx > 0) or (sample_end_idx < sequence_length):
            data = np.zeros(
                shape=(sequence_length,) + input_arr.shape[1:],
                dtype=input_arr.dtype)
            if sample_start_idx > 0:
                data[:sample_start_idx] = sample[0]
            if sample_end_idx < sequence_length:
                data[sample_end_idx:] = sample[-1]
            data[sample_start_idx:sample_end_idx] = sample
        result[key] = data
    return result

# normalize data
def get_data_stats(data):
    data = data.reshape(-1,data.shape[-1])
    stats = {
        'min': np.min(data, axis=0),
        'max': np.max(data, axis=0),
        'std': np.std(data, axis=0)
    }
    return stats

def normalize_data(data, stats):
    # nomalize to [0,1]
    ndata = (data - stats['min']) / (stats['max'] - stats['min'])
    # normalize to [-1, 1]
    ndata = ndata * 2 - 1
    return ndata

def unnormalize_data(ndata, stats):
    ndata = (ndata + 1) / 2
    data = ndata * (stats['max'] - stats['min']) + stats['min']
    return data

# dataset
class PushTStateDataset(torch.utils.data.Dataset):
    def __init__(self, dataset_path,
                 pred_horizon, obs_horizon, action_horizon):

        # read from zarr dataset
        dataset_root = zarr.open(dataset_path, 'r')
        # All demonstration episodes are concatinated in the first dimension N
        train_data = {
            # (N, action_dim)
            'action': dataset_root['data']['action'][:],
            # (N, obs_dim)
            'obs': dataset_root['data']['state'][:]
        }
        # Marks one-past the last index for each episode
        episode_ends = dataset_root['meta']['episode_ends'][:]

        # compute start and end of each state-action sequence
        # also handles padding
        indices = create_sample_indices(
            episode_ends=episode_ends,
            sequence_length=pred_horizon,
            # add padding such that each timestep in the dataset are seen
            pad_before=obs_horizon-1,
            pad_after=action_horizon-1)

        # compute statistics and normalized data to [-1,1]
        stats = dict()
        normalized_train_data = dict()
        for key, data in train_data.items():
            stats[key] = get_data_stats(data)
            normalized_train_data[key] = normalize_data(data, stats[key])
            stats[key+'_normalized'] = get_data_stats(normalized_train_data[key])

        self.indices = indices
        self.stats = stats
        self.normalized_train_data = normalized_train_data
        self.pred_horizon = pred_horizon
        self.action_horizon = action_horizon
        self.obs_horizon = obs_horizon

    def __len__(self):
        # all possible segments of the dataset
        return len(self.indices)

    def __getitem__(self, idx):
        # get the start/end indices for this datapoint
        buffer_start_idx, buffer_end_idx, \
            sample_start_idx, sample_end_idx = self.indices[idx]

        # get nomralized data using these indices
        nsample = sample_sequence(
            train_data=self.normalized_train_data,
            sequence_length=self.pred_horizon,
            buffer_start_idx=buffer_start_idx,
            buffer_end_idx=buffer_end_idx,
            sample_start_idx=sample_start_idx,
            sample_end_idx=sample_end_idx
        )

        # discard unused observations
        nsample['obs'] = nsample['obs'][:self.obs_horizon,:]
        return nsample


In [5]:
#@markdown ### **Dataset Demo**

# download demonstration data from Google Drive
dataset_path = "pusht_cchi_v7_replay.zarr.zip"
if not os.path.isfile(dataset_path):
    id = "1KY1InLurpMvJDRb14L9NlXT_fEsCvVUq&confirm=t"
    gdown.download(id=id, output=dataset_path, quiet=False)

# parameters
pred_horizon = 16
obs_horizon = 2
action_horizon = 8
#|o|o|                             observations: 2
#| |a|a|a|a|a|a|a|a|               actions executed: 8
#|p|p|p|p|p|p|p|p|p|p|p|p|p|p|p|p| actions predicted: 16

# create dataset from file
dataset = PushTStateDataset(
    dataset_path=dataset_path,
    pred_horizon=pred_horizon,
    obs_horizon=obs_horizon,
    action_horizon=action_horizon
)
# save training data statistics (min, max) for each dim
stats = dataset.stats

# create dataloader
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    num_workers=4,
    shuffle=True,
    # accelerate cpu-gpu transfer
    pin_memory=True,
    # don't kill worker process afte each epoch
    persistent_workers=True
)

# visualize data in batch
batch = next(iter(dataloader))
print("batch['obs'].shape:", batch['obs'].shape)
print("batch['action'].shape", batch['action'].shape)

batch['obs'].shape: torch.Size([256, 2, 5])
batch['action'].shape torch.Size([256, 16, 2])


In [6]:
# dataset size
print(len(dataset))
stats['action_normalized']['std'].mean()

24208


np.float32(0.40121755)

In [14]:
#@markdown ### **Network**
#@markdown
#@markdown Defines a 1D UNet architecture `ConditionalUnet1D`
#@markdown as the noies prediction network
#@markdown
#@markdown Components
#@markdown - `SinusoidalPosEmb` Positional encoding for the diffusion iteration k
#@markdown - `Downsample1d` Strided convolution to reduce temporal resolution
#@markdown - `Upsample1d` Transposed convolution to increase temporal resolution
#@markdown - `Conv1dBlock` Conv1d --> GroupNorm --> Mish
#@markdown - `ConditionalResidualBlock1D` Takes two inputs `x` and `cond`. \
#@markdown `x` is passed through 2 `Conv1dBlock` stacked together with residual connection.
#@markdown `cond` is applied to `x` with [FiLM](https://arxiv.org/abs/1709.07871) conditioning.

# this is the unet for the robotics problem, modify this to use inductive moment matching
# IMM

'''
η_t = t/α_t.
OT-FM schedule: α_t = 1-t, σ_t = t
x_t = α_t·x + σ_t·ε # ε ~ N(0, I)
cout(t) = −t·σ_d

'''

import math
import torch
import torch.nn as nn
from typing import Union
from tqdm import tqdm
class IMMloss(nn.Module):
    """
    IMM loss function using the Laplace kernel.
    """

    def __init__(self, obs_horizon, pred_horizon, num_particles):
        super(IMMloss, self).__init__()
        self.obs_horizon = obs_horizon
        self.pred_horizon = pred_horizon
        self.num_particles = num_particles
        
    def laplace_kernel(self, x, y, w_scale, eps=0.006, dim_normalize=True):
        """
        Laplace kernel: exp(w_scale * max(||x-y||_2, eps)/D)
        
        Args:
            x, y: input tensors
            w_scale: scaling factor (time-dependent)
            eps: small constant to avoid undefined gradients
            dim_normalize: whether to normalize by dimensionality D
        """
        D = x.shape[-1] if dim_normalize else 1.0
        distance = torch.norm(x - y, p=2, dim=-1)
        # Apply max to avoid zero gradients
        distance = torch.clamp(distance, min=eps)
        return torch.exp(-w_scale * distance / D)
    
    def forward(self, model_outputs, time_weights, stop_gradient_outputs):
        """
        Compute the IMM loss for a batch of model outputs.
        
        Args:
            model_outputs: Dictionary containing:
                - ys_t: outputs from time t to s [B, self.pred_horizon, self.obs_horizon]
                - ys_r: outputs from time r to s [B, self.pred_horizon, self.obs_horizon]
                - w_scale: time-dependent scaling factors [B]
            time_weights: w(s,t) weights [B/M]
            stop_gradient_outputs: Optional dictionary with same structure as model_outputs
                                   containing the detached outputs (θ-)
        """
        
        # Extract batch size and reshape for group processing
        batch_size = model_outputs['ys_t'].shape[0]
        M = self.num_particles
        num_groups = batch_size // M
        
        # Flatten pred_horizon and obs_horizon dimensions before reshaping
        # Reshape tensors to [num_groups, M, D]
        ys_t = model_outputs['ys_t'].reshape(batch_size, -1).reshape(num_groups, M, -1)
        ys_r_stop = stop_gradient_outputs['ys_r'].reshape(batch_size, -1).reshape(num_groups, M, -1)
        w_scale = model_outputs['w_scale'].reshape(num_groups, M)
        

        
        # Reshape time weights to [num_groups] by extracting the first element of each group
        time_weights = time_weights.reshape(num_groups, M)[:,0].reshape(-1)
        
        total_loss = 0.0
        for i in range(num_groups):
            group_loss = 0.0
            
            # Compute the kernel matrices
            for j in range(M):
                for k in range(M):
                    # First term: k(f_s,t^θ(x_t^(i,j)), f_s,t^θ(x_t^(i,k)))
                    term1 = self.laplace_kernel(ys_t[i, j], ys_t[i, k], w_scale[i, j])
                    
                    # Second term: k(f_s,r^θ-(x_r^(i,j)), f_s,r^θ-(x_r^(i,k)))
                    term2 = self.laplace_kernel(ys_r_stop[i, j], ys_r_stop[i, k], w_scale[i, j])
                    
                    # Third term: -2k(f_s,t^θ(x_t^(i,j)), f_s,r^θ-(x_r^(i,k)))
                    term3 = -2.0 * self.laplace_kernel(ys_t[i, j], ys_r_stop[i, k], w_scale[i, j])
                    
                    # Sum up the terms
                    group_loss += term1 + term2 + term3
            
            # Apply time-dependent weighting
            group_loss = group_loss * time_weights[i] / (M * M)
            total_loss += group_loss
        
        # Average over the number of groups
        return total_loss / num_groups

class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class Downsample1d(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv1d(dim, dim, 3, 2, 1)

    def forward(self, x):
        return self.conv(x)

class Upsample1d(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)

    def forward(self, x):
        return self.conv(x)


class Conv1dBlock(nn.Module):
    '''
        Conv1d --> GroupNorm --> Mish
    '''

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        self.block = nn.Sequential(
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
            nn.GroupNorm(n_groups, out_channels),
            nn.Mish(),
        )

    def forward(self, x):
        return self.block(x)


class ConditionalResidualBlock1D(nn.Module):
    def __init__(self,
            in_channels,
            out_channels,
            cond_dim,
            kernel_size=3,
            n_groups=8):
        super().__init__()

        self.blocks = nn.ModuleList([
            Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),
            Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),
        ])

        # FiLM modulation https://arxiv.org/abs/1709.07871
        # predicts per-channel scale and bias
        cond_channels = out_channels * 2
        self.out_channels = out_channels
        self.cond_encoder = nn.Sequential(
            nn.Mish(),
            nn.Linear(cond_dim, cond_channels),
            nn.Unflatten(-1, (-1, 1))
        )

        # make sure dimensions compatible
        self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \
            if in_channels != out_channels else nn.Identity()

    def forward(self, x, cond):
        '''
            x : [ batch_size x in_channels x horizon ]
            cond : [ batch_size x cond_dim]

            returns:
            out : [ batch_size x out_channels x horizon ]
        '''
        out = self.blocks[0](x)
        embed = self.cond_encoder(cond)

        embed = embed.reshape(
            embed.shape[0], 2, self.out_channels, 1)
        scale = embed[:,0,...]
        bias = embed[:,1,...]
        out = scale * out + bias

        out = self.blocks[1](out)
        out = out + self.residual_conv(x)
        return out


class ConditionalUnet1D(nn.Module):
    def __init__(self,
        input_dim,
        global_cond_dim,
        diffusion_step_embed_dim=256,
        down_dims=[256,512,1024],
        kernel_size=5,
        n_groups=8
        ):
        """
        input_dim: Dim of actions.
        global_cond_dim: Dim of global conditioning applied with FiLM
          in addition to diffusion step embedding. This is usually obs_horizon * obs_dim
        diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k
        down_dims: Channel size for each UNet level.
          The length of this array determines numebr of levels.
        kernel_size: Conv kernel size
        n_groups: Number of groups for GroupNorm
        """

        super().__init__()
        all_dims = [input_dim] + list(down_dims)
        start_dim = down_dims[0]

        dsed = diffusion_step_embed_dim
        diffusion_step_encoder = nn.Sequential(
            SinusoidalPosEmb(dsed),
            nn.Linear(dsed, dsed * 4),
            nn.Mish(),
            nn.Linear(dsed * 4, dsed),
        )
        
        # Second encoder for timestep s (for IMM)
        diffusion_step_encoder_s = nn.Sequential(
            SinusoidalPosEmb(dsed),
            nn.Linear(dsed, dsed * 4),
            nn.Mish(),
            nn.Linear(dsed * 4, dsed),
        )
            
        # Total conditioning dimensions: t embedding + s embedding + global conditioning
        cond_dim = dsed * 2 + global_cond_dim

        in_out = list(zip(all_dims[:-1], all_dims[1:]))
        mid_dim = all_dims[-1]
        self.mid_modules = nn.ModuleList([
            ConditionalResidualBlock1D(
                mid_dim, mid_dim, cond_dim=cond_dim,
                kernel_size=kernel_size, n_groups=n_groups
            ),
            ConditionalResidualBlock1D(
                mid_dim, mid_dim, cond_dim=cond_dim,
                kernel_size=kernel_size, n_groups=n_groups
            ),
        ])

        down_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (len(in_out) - 1)
            down_modules.append(nn.ModuleList([
                ConditionalResidualBlock1D(
                    dim_in, dim_out, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                ConditionalResidualBlock1D(
                    dim_out, dim_out, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                Downsample1d(dim_out) if not is_last else nn.Identity()
            ]))

        up_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (len(in_out) - 1)
            up_modules.append(nn.ModuleList([
                ConditionalResidualBlock1D(
                    dim_out*2, dim_in, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                ConditionalResidualBlock1D(
                    dim_in, dim_in, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                Upsample1d(dim_in) if not is_last else nn.Identity()
            ]))

        final_conv = nn.Sequential(
            Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
            nn.Conv1d(start_dim, input_dim, 1),
        )

        self.diffusion_step_encoder = diffusion_step_encoder
        self.diffusion_step_encoder_s = diffusion_step_encoder_s
        self.up_modules = up_modules
        self.down_modules = down_modules
        self.final_conv = final_conv

        print("number of parameters: {:e}".format(
            sum(p.numel() for p in self.parameters()))
        )

    def forward(self,
            sample: torch.Tensor,
            timestep: Union[torch.Tensor, float],
            timestep_s: Union[torch.Tensor, float],
            global_cond):
        """
        x: (B,T,input_dim)
        timestep: (B,) or int, diffusion step t
        timestep_s: (B,) or int, diffusion step s (for IMM)
        global_cond: (B,global_cond_dim)
        output: (B,T,input_dim)
        """
        # (B,T,C)
        sample = sample.moveaxis(-1,-2)
        # (B,C,T)

        # 1. time embedding for t
        # timesteps = timestep
        # if not torch.is_tensor(timesteps):
        #     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
        #     timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        # elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
        #     timesteps = timesteps[None].to(sample.device)
        # # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        # timesteps = timesteps.expand(sample.shape[0])
        
        # Get time embedding for t
        t_emb = self.diffusion_step_encoder(timestep)
        
        # 2. time embedding for s
        s_emb = self.diffusion_step_encoder_s(timestep_s)
        
        # Combine t and s embeddings
        global_feature = torch.cat([t_emb, s_emb], dim=-1)

        if global_cond is not None:
            global_feature = torch.cat([
                global_feature, global_cond
            ], axis=-1)

        x = sample
        h = []
        for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            h.append(x)
            x = downsample(x)

        for mid_module in self.mid_modules:
            x = mid_module(x, global_feature)

        for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            x = upsample(x)

        x = self.final_conv(x)

        # (B,C,T)
        x = x.moveaxis(-1,-2)
        # (B,T,C)
        return x

class RoboIMM:
    """
    Simplified class for IMM with the robotics UNet.
    """
    def __init__(
        self,
        model,
        sigma_data,
        obs_horizon,
        pred_horizon,
        num_particles
    ):
        """
        Initialize the IMM sampler.
        
        Args:
            model: The ConditionalUnet1D model
            sigma_data: Data standard deviation
        """
        self.model = model
        self.sigma_data = sigma_data
        self.obs_horizon = obs_horizon
        self.pred_horizon = pred_horizon
        self.num_particles = num_particles

        self.loss = IMMloss(obs_horizon=obs_horizon, pred_horizon=pred_horizon, num_particles=num_particles)

    def get_alpha_sigma(self, t):
        """Get alpha and sigma values for time t."""
        # Using the "flow matching" schedule
        alpha_t = (1 - t)
        sigma_t = t
        return alpha_t, sigma_t
    
    # def euler_step(self, yt, pred, t, s):
    #     """Euler step for flow matching."""
    #     return yt - (t - s) * self.sigma_data * pred
    
    # def edm_step(self, yt, pred, t, s):
    #     """EDM step for sampling."""
    #     alpha_t, sigma_t = self.get_alpha_sigma(t)
    #     alpha_s, sigma_s = self.get_alpha_sigma(s)
         
    #     c_skip = (alpha_t * alpha_s + sigma_t * sigma_s) / (alpha_t**2 + sigma_t**2)
    #     c_out = -(alpha_s * sigma_t - alpha_t * sigma_s) * (alpha_t**2 + sigma_t**2).rsqrt() * self.sigma_data
        
    #     return c_skip * yt + c_out * pred
    
    def ddim(self, yt, y, s, t):
        alpha_t, sigma_t = self.get_alpha_sigma(t)
        alpha_s, sigma_s = self.get_alpha_sigma(s)

        alpha_s = alpha_s.reshape(-1,1,1)
        sigma_s = sigma_s.reshape(-1,1,1)
        alpha_t = alpha_t.reshape(-1,1,1)
        sigma_t = sigma_t.reshape(-1,1,1)
        
        ys = (alpha_s -   alpha_t * sigma_s / sigma_t) * y + sigma_s / sigma_t * yt
        return ys
    
    def sample(self, shape, steps=20, global_cond=None, sampling_method="ddim"):
        """
        Generate samples using IMM sampling.
        
        Args:
            shape: Shape of the samples to generate
            steps: Number of sampling steps
            global_cond: Global conditioning
            sampling_method: "ddim"
            
        Returns:
            Generated samples
        """
        device = next(self.model.parameters()).device
        
        # Initialize with noise
        x = torch.randn(shape, device=device) * self.sigma_data
        
        # Define time steps (uniform steps from 1 to 0)
        times = torch.linspace(0.994, 0.006, steps + 1, device=device)
        
        for i in range(steps):
            t = times[i]
            s = times[i + 1]
            
            # Create batched time tensors
            t_batch = torch.full((shape[0],), t, device=device)
            s_batch = torch.full((shape[0],), s, device=device)
            
            # Run model forward
            with torch.no_grad():
                pred = self.predict(x, t_batch, s_batch, global_cond)
            
            # Apply sampling function based on method
            # if sampling_method == "ddim":
            #     x = self.ddim(x, pred, s_batch.view(-1, 1, 1), t_batch.view(-1, 1, 1))
            # else:
            #     raise ValueError(f"Unknown sampling method: {sampling_method}")
        
        return pred

    def calculate_weights(self, s_times, t_times):
        """
        Calculate the time-dependent weighting function w(s,t)
        """
        b = 5  # Hyperparameter from paper
        a = 1    # Hyperparameter from paper (a ∈ {1, 2})

        alpha_t, sigma_t = self.get_alpha_sigma(t_times)
        
        # Calculate log-SNR values
        log_snr_t = 2 * torch.log(alpha_t / sigma_t)
        dlog_snr_t = 2 / (torch.square(t_times) - t_times)
        
        # Calculate coefficient based on equation 13
        sigmoid_term = torch.sigmoid(b - log_snr_t)
        
        snr_term = (alpha_t ** a) / (alpha_t ** 2 + sigma_t ** 2)
        
        return 0.5 * sigmoid_term * -1.0 * dlog_snr_t * snr_term
    
    def predict(self, xt, t, s, obs_cond):
        c = 1000.0
        cskip = 1.0
        cout = -(t-s) * self.sigma_data
        c_timestep = c * t
        c_timestep_s = c * s
        alpha_t, sigma_t = self.get_alpha_sigma(t)
        c_in = (torch.pow(alpha_t, 2) + torch.pow(sigma_t, 2)).rsqrt() / self.sigma_data
        xs = self.model(xt*c_in.reshape(-1,1,1), c_timestep, c_timestep_s, obs_cond)
        return cskip * xt + cout.reshape(-1,1,1) * xs

    def train(self, train_loader, val_loader, num_epochs=100, lr=1e-4, device='cuda'):
        """
        Train the model.
        """

        # Standard ADAM optimizer
        # Note that EMA parametesr are not optimized
        optimizer = torch.optim.AdamW(
            params=self.model.parameters(),
            lr=lr, weight_decay=0)

        # Cosine LR schedule
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs)

        self.model.train()
        
        with tqdm(range(num_epochs), desc='Epoch') as tglobal:
            # epoch loop
            for epoch_idx in tglobal:
                # batch loop
                with tqdm(train_loader, desc='Batch', leave=False) as tepoch:
                    for batch_idx, batch in enumerate(tepoch):
                        nobs = batch['obs'].to(device)
                        naction = batch['action'].to(device)
                        B = nobs.shape[0]

                        num_groups = B // self.num_particles
                        s_times = torch.rand(num_groups, device=device)
                        t_times = s_times + (1 - s_times) * torch.rand(num_groups, device=device)
                        r_times = s_times + (t_times - s_times) * torch.rand(num_groups, device=device)

                        # times need to be shape (B,), currently they are shape (num_groups,)
                        s_times = s_times.reshape(-1,1).expand(num_groups, self.num_particles).reshape(-1)
                        t_times = t_times.reshape(-1,1).expand(num_groups, self.num_particles).reshape(-1)
                        r_times = r_times.reshape(-1,1).expand(num_groups, self.num_particles).reshape(-1)

                        noise = torch.randn_like(naction) * self.sigma_data

                        x_t = self.ddim(yt=noise, y=naction, s=t_times, t=torch.ones_like(t_times))
                        x_r = self.ddim(yt=x_t, y=naction, s=r_times, t=t_times)
                                                
                        # observation as FiLM conditioning
                        # (B, obs_horizon, obs_dim)
                        obs_cond = nobs[:,:self.obs_horizon,:]
                        # (B, obs_horizon * obs_dim)
                        obs_cond = obs_cond.flatten(start_dim=1)

                        optimizer.zero_grad()
                        # pred_grad = self.model(x_t, t_times, s_times, obs_cond)
                        pred_grad = self.predict(x_t, t_times, s_times, obs_cond)

                        with torch.no_grad():
                            # pred_nograd = self.model(x_r, r_times, s_times, obs_cond)
                            pred_nograd = self.predict(x_r, r_times, s_times, obs_cond)
    
                        time_weights = self.calculate_weights(s_times, t_times)
                        
                        # Reshape predictions to match expected dimensions
                        # Assuming pred_grad and pred_nograd are [B, sequence_length, action_dim]
                        # Reshape to [B, self.pred_horizon, self.obs_horizon]
                        model_outputs = {
                            'ys_t': pred_grad.reshape(B, self.pred_horizon, self.obs_horizon),
                            'w_scale': 1.0 / torch.abs((t_times - s_times) * self.sigma_data)
                        }
                        
                        stop_gradient_outputs = {
                            'ys_r': pred_nograd.reshape(B, self.pred_horizon, self.obs_horizon).detach()
                        }
                        
                        loss = self.loss(model_outputs, time_weights, stop_gradient_outputs)
                        
                        loss.backward()
                        optimizer.step()
                        lr_scheduler.step()

                        if batch_idx % 10 == 0:
                            print(f"Epoch {epoch_idx}, Batch {batch_idx}, Loss: {loss.item()}")
                        if batch_idx % 100 == 0 and epoch_idx % 10 == 0:
                            # save model checkpoint
                            torch.save({
                                'epoch': epoch_idx,
                                'model_state_dict': self.model.state_dict(),
                                'optimizer_state_dict': optimizer.state_dict(),
                                'loss': loss.item()
                            }, f"ckpts/model_checkpoint_{epoch_idx}_{batch_idx}_{loss.item()}.pth")
                                

In [15]:
#@markdown ### **Network Demo**

# observation and action dimensions corrsponding to
# the output of PushTEnv
obs_dim = 5
action_dim = 2

# create network object
noise_pred_net = ConditionalUnet1D(
    input_dim=action_dim,
    global_cond_dim=obs_dim*obs_horizon
)

imm = RoboIMM(
    model=noise_pred_net,
    sigma_data=stats['action_normalized']['std'].mean(),
    obs_horizon=obs_horizon,
    pred_horizon=pred_horizon,
    num_particles=4
)

# example inputs
noised_action = torch.randn((1, pred_horizon, action_dim))
obs = torch.zeros((1, obs_horizon, obs_dim))
diffusion_iter = torch.zeros((1,))
diffusion_iter_s = torch.zeros((1,))


r = imm.sample(shape=(1, pred_horizon, action_dim), steps=10, global_cond=obs.flatten(start_dim=1))

# # the noise prediction network
# # takes noisy action, diffusion iteration and observation as input
# # predicts the noise added to action
# noise = noise_pred_net(
#     sample=noised_action,
#     timestep=diffusion_iter,
#     timestep_s=diffusion_iter_s,
#     global_cond=obs.flatten(start_dim=1))

# # illustration of removing noise
# # the actual noise removal is performed by NoiseScheduler
# # and is dependent on the diffusion noise schedule
# denoised_action = noised_action - noise

# # for this demo, we use DDPMScheduler with 100 diffusion iterations
# num_diffusion_iters = 100
# noise_scheduler = DDPMScheduler(
#     num_train_timesteps=num_diffusion_iters,
#     # the choise of beta schedule has big impact on performance
#     # we found squared cosine works the best
#     beta_schedule='squaredcos_cap_v2',
#     # clip output to [-1,1] to improve stability
#     clip_sample=True,
#     # our network predicts noise (instead of denoised action)
#     prediction_type='epsilon'
# )

# device transfer
device = torch.device('cuda')
_ = noise_pred_net.to(device)

number of parameters: 6.954880e+07


In [12]:
imm.train(dataloader, None, num_epochs=100, lr=1e-4, device='cuda')

Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 0, Batch 0, Loss: 0.12795916199684143




Epoch 0, Batch 10, Loss: 0.10586139559745789




Epoch 0, Batch 20, Loss: 0.09289903193712234




Epoch 0, Batch 30, Loss: 0.0857173353433609




Epoch 0, Batch 40, Loss: 0.06332328170537949




Epoch 0, Batch 50, Loss: 0.06732785701751709




Epoch 0, Batch 60, Loss: 0.06578713655471802




Epoch 0, Batch 70, Loss: 0.052452586591243744




Epoch 0, Batch 80, Loss: 0.054306551814079285




Epoch 0, Batch 90, Loss: 0.05278688669204712


Epoch:   1%|          | 1/100 [01:21<2:15:08, 81.90s/it]

Epoch 1, Batch 0, Loss: 0.045662663877010345




Epoch 1, Batch 10, Loss: 0.059488408267498016




Epoch 1, Batch 20, Loss: 0.05367956683039665




Epoch 1, Batch 30, Loss: 0.054764293134212494




Epoch 1, Batch 40, Loss: 0.05800040811300278




Epoch 1, Batch 50, Loss: 0.042720019817352295




Epoch 1, Batch 60, Loss: 0.049918971955776215




Epoch 1, Batch 70, Loss: 0.06473232060670853




Epoch 1, Batch 80, Loss: 0.052694015204906464




Epoch 1, Batch 90, Loss: 0.0616900771856308


Epoch:   2%|▏         | 2/100 [02:38<2:08:36, 78.74s/it]

Epoch 2, Batch 0, Loss: 0.05130188167095184




Epoch 2, Batch 10, Loss: 0.04614740237593651




Epoch 2, Batch 20, Loss: 0.040759019553661346




Epoch 2, Batch 30, Loss: 0.04335165023803711




Epoch 2, Batch 40, Loss: 0.057305824011564255




Epoch 2, Batch 50, Loss: 0.04622232913970947




Epoch 2, Batch 60, Loss: 0.05721656605601311




Epoch 2, Batch 70, Loss: 0.04400867223739624




Epoch 2, Batch 80, Loss: 0.039843518286943436




Epoch 2, Batch 90, Loss: 0.03447013348340988


Epoch:   3%|▎         | 3/100 [03:48<2:00:40, 74.65s/it]

Epoch 3, Batch 0, Loss: 0.047640133649110794




Epoch 3, Batch 10, Loss: 0.04685727506875992




Epoch 3, Batch 20, Loss: 0.0504700131714344




Epoch 3, Batch 30, Loss: 0.04217686504125595




Epoch 3, Batch 40, Loss: 0.048061106353998184




Epoch 3, Batch 50, Loss: 0.04259515553712845




Epoch 3, Batch 60, Loss: 0.037044793367385864




Epoch 3, Batch 70, Loss: 0.046266667544841766




Epoch 3, Batch 80, Loss: 0.038989242166280746




Epoch 3, Batch 90, Loss: 0.034498006105422974


Epoch:   4%|▍         | 4/100 [04:54<1:54:23, 71.49s/it]

Epoch 4, Batch 0, Loss: 0.04309607297182083




Epoch 4, Batch 10, Loss: 0.04341019317507744




Epoch 4, Batch 20, Loss: 0.049957383424043655




Epoch 4, Batch 30, Loss: 0.03645448386669159




Epoch 4, Batch 40, Loss: 0.04578027129173279




Epoch 4, Batch 50, Loss: 0.035580217838287354




Epoch 4, Batch 60, Loss: 0.05699731037020683




Epoch 4, Batch 70, Loss: 0.05053478106856346




Epoch 4, Batch 80, Loss: 0.046329300850629807




Epoch 4, Batch 90, Loss: 0.042615026235580444


Epoch:   5%|▌         | 5/100 [06:06<1:53:13, 71.51s/it]

Epoch 5, Batch 0, Loss: 0.04138029366731644




Epoch 5, Batch 10, Loss: 0.057857733219861984




Epoch 5, Batch 20, Loss: 0.031702421605587006




Epoch 5, Batch 30, Loss: 0.04546373337507248




Epoch 5, Batch 40, Loss: 0.04104429855942726




Epoch 5, Batch 50, Loss: 0.04053748771548271




Epoch 5, Batch 60, Loss: 0.04126056656241417




Epoch 5, Batch 70, Loss: 0.03566534444689751




Epoch 5, Batch 80, Loss: 0.039022136479616165




Epoch 5, Batch 90, Loss: 0.03637363761663437


Epoch:   6%|▌         | 6/100 [07:25<1:55:52, 73.96s/it]

Epoch 6, Batch 0, Loss: 0.03832963854074478




Epoch 6, Batch 10, Loss: 0.051607292145490646




Epoch 6, Batch 20, Loss: 0.03726348653435707




Epoch 6, Batch 30, Loss: 0.04049791768193245




Epoch 6, Batch 40, Loss: 0.044263094663619995




Epoch 6, Batch 50, Loss: 0.044757623225450516




Epoch 6, Batch 60, Loss: 0.043422210961580276




Epoch 6, Batch 70, Loss: 0.041713133454322815




Epoch 6, Batch 80, Loss: 0.042056165635585785




Epoch 6, Batch 90, Loss: 0.03544692322611809


Epoch:   7%|▋         | 7/100 [08:34<1:52:26, 72.54s/it]

Epoch 7, Batch 0, Loss: 0.028946777805685997




Epoch 7, Batch 10, Loss: 0.03480581194162369




Epoch 7, Batch 20, Loss: 0.03390072286128998




Epoch 7, Batch 30, Loss: 0.045918937772512436




Epoch 7, Batch 40, Loss: 0.052861861884593964




Epoch 7, Batch 50, Loss: 0.047693077474832535




Epoch 7, Batch 60, Loss: 0.04675474762916565




Epoch 7, Batch 70, Loss: 0.0334862545132637




Epoch 7, Batch 80, Loss: 0.02957005985081196




Epoch 7, Batch 90, Loss: 0.03355449065566063


Epoch:   8%|▊         | 8/100 [09:56<1:55:51, 75.55s/it]

Epoch 8, Batch 0, Loss: 0.04076552763581276




Epoch 8, Batch 10, Loss: 0.042578551918268204




Epoch 8, Batch 20, Loss: 0.038031063973903656




Epoch 8, Batch 30, Loss: 0.04812188819050789




Epoch 8, Batch 40, Loss: 0.03457009792327881




Epoch 8, Batch 50, Loss: 0.03964485973119736




Epoch 8, Batch 60, Loss: 0.04222320020198822




Epoch 8, Batch 70, Loss: 0.05127103254199028




Epoch 8, Batch 80, Loss: 0.0351167730987072




Epoch 8, Batch 90, Loss: 0.035705991089344025


Epoch:   9%|▉         | 9/100 [11:06<1:51:42, 73.66s/it]

Epoch 9, Batch 0, Loss: 0.02739749662578106




Epoch 9, Batch 10, Loss: 0.05660738795995712




Epoch 9, Batch 20, Loss: 0.04038850963115692




Epoch 9, Batch 30, Loss: 0.04669107124209404




Epoch 9, Batch 40, Loss: 0.036677286028862




Epoch 9, Batch 50, Loss: 0.025714442133903503




Epoch 9, Batch 60, Loss: 0.03209073469042778




Epoch 9, Batch 70, Loss: 0.036770351231098175




Epoch 9, Batch 80, Loss: 0.033361371606588364




Epoch 9, Batch 90, Loss: 0.0379304364323616


Epoch:  10%|█         | 10/100 [12:22<1:51:30, 74.33s/it]

Epoch 10, Batch 0, Loss: 0.03877926990389824




Epoch 10, Batch 10, Loss: 0.03859417513012886




Epoch 10, Batch 20, Loss: 0.041698165237903595




Epoch 10, Batch 30, Loss: 0.03474869579076767




Epoch 10, Batch 40, Loss: 0.0313715897500515




Epoch 10, Batch 50, Loss: 0.036272384226322174




Epoch 10, Batch 60, Loss: 0.04732716828584671




Epoch 10, Batch 70, Loss: 0.039849113672971725




Epoch 10, Batch 80, Loss: 0.04006045684218407




Epoch 10, Batch 90, Loss: 0.035033270716667175


Epoch:  11%|█         | 11/100 [13:41<1:52:35, 75.91s/it]

Epoch 11, Batch 0, Loss: 0.03717106580734253




Epoch 11, Batch 10, Loss: 0.028973719105124474




Epoch 11, Batch 20, Loss: 0.03675887733697891




Epoch 11, Batch 30, Loss: 0.05634795129299164




Epoch 11, Batch 40, Loss: 0.032877713441848755




Epoch 11, Batch 50, Loss: 0.03192346543073654




Epoch 11, Batch 60, Loss: 0.04153842478990555




Epoch 11, Batch 70, Loss: 0.031432218849658966




Epoch 11, Batch 80, Loss: 0.03220425918698311




Epoch 11, Batch 90, Loss: 0.03650717809796333


Epoch:  12%|█▏        | 12/100 [14:57<1:51:24, 75.96s/it]

Epoch 12, Batch 0, Loss: 0.03248750790953636




Epoch 12, Batch 10, Loss: 0.03921014443039894




Epoch 12, Batch 20, Loss: 0.03436442092061043




Epoch 12, Batch 30, Loss: 0.040052060037851334




Epoch 12, Batch 40, Loss: 0.04669908061623573




Epoch 12, Batch 50, Loss: 0.04218105971813202




Epoch 12, Batch 60, Loss: 0.03439385071396828




Epoch 12, Batch 70, Loss: 0.041462067514657974




Epoch 12, Batch 80, Loss: 0.03267621248960495




Epoch 12, Batch 90, Loss: 0.03053591586649418


Epoch:  13%|█▎        | 13/100 [15:50<1:40:11, 69.10s/it]

Epoch 13, Batch 0, Loss: 0.03378541022539139




Epoch 13, Batch 10, Loss: 0.031606417149305344




Epoch 13, Batch 20, Loss: 0.03919863700866699




Epoch 13, Batch 30, Loss: 0.04007682576775551




Epoch 13, Batch 40, Loss: 0.029358351603150368




Epoch 13, Batch 50, Loss: 0.032627079635858536




Epoch 13, Batch 60, Loss: 0.03265605494379997




Epoch 13, Batch 70, Loss: 0.03040974587202072




Epoch 13, Batch 80, Loss: 0.031715430319309235




Epoch 13, Batch 90, Loss: 0.026765581220388412


Epoch:  14%|█▍        | 14/100 [16:22<1:22:48, 57.77s/it]

Epoch 14, Batch 0, Loss: 0.03784682974219322




Epoch 14, Batch 10, Loss: 0.0303560309112072




Epoch 14, Batch 20, Loss: 0.03110145963728428




Epoch 14, Batch 30, Loss: 0.03130623698234558




Epoch 14, Batch 40, Loss: 0.030808059498667717




Epoch 14, Batch 50, Loss: 0.03566257655620575




Epoch 14, Batch 60, Loss: 0.03412353992462158




Epoch 14, Batch 70, Loss: 0.032970789819955826




Epoch 14, Batch 80, Loss: 0.035809874534606934




Epoch 14, Batch 90, Loss: 0.03894525766372681


Epoch:  15%|█▌        | 15/100 [16:54<1:10:38, 49.87s/it]

Epoch 15, Batch 0, Loss: 0.031276803463697433




Epoch 15, Batch 10, Loss: 0.03547818213701248




Epoch 15, Batch 20, Loss: 0.03897424042224884




Epoch 15, Batch 30, Loss: 0.0332731157541275




Epoch 15, Batch 40, Loss: 0.028994400054216385




Epoch 15, Batch 50, Loss: 0.03468874841928482




Epoch 15, Batch 60, Loss: 0.028222355991601944




Epoch 15, Batch 70, Loss: 0.02713250182569027




Epoch 15, Batch 80, Loss: 0.03111487813293934




Epoch 15, Batch 90, Loss: 0.028108051046729088


Epoch:  16%|█▌        | 16/100 [17:25<1:02:06, 44.36s/it]

Epoch 16, Batch 0, Loss: 0.02882339246571064




Epoch 16, Batch 10, Loss: 0.023288536816835403




Epoch 16, Batch 20, Loss: 0.03772904351353645




Epoch 16, Batch 30, Loss: 0.02927282452583313




Epoch 16, Batch 40, Loss: 0.04090932384133339




Epoch 16, Batch 50, Loss: 0.03021736443042755




Epoch 16, Batch 60, Loss: 0.038617704063653946




Epoch 16, Batch 70, Loss: 0.03824158012866974




Epoch 16, Batch 80, Loss: 0.026118751615285873




Epoch 16, Batch 90, Loss: 0.03593096137046814


Epoch:  17%|█▋        | 17/100 [17:57<56:02, 40.52s/it]  

Epoch 17, Batch 0, Loss: 0.04393058642745018




Epoch 17, Batch 10, Loss: 0.03620392829179764




Epoch 17, Batch 20, Loss: 0.02840954251587391




Epoch 17, Batch 30, Loss: 0.02754553034901619




Epoch 17, Batch 40, Loss: 0.038040127605199814




Epoch 17, Batch 50, Loss: 0.03156360238790512




Epoch 17, Batch 60, Loss: 0.026863325387239456




Epoch 17, Batch 70, Loss: 0.0296863354742527




Epoch 17, Batch 80, Loss: 0.04090356081724167




Epoch 17, Batch 90, Loss: 0.02736826054751873


Epoch:  18%|█▊        | 18/100 [18:28<51:42, 37.83s/it]

Epoch 18, Batch 0, Loss: 0.036974016577005386




Epoch 18, Batch 10, Loss: 0.028587928041815758




Epoch 18, Batch 20, Loss: 0.04034857451915741




Epoch 18, Batch 30, Loss: 0.03316095471382141




Epoch 18, Batch 40, Loss: 0.02997962385416031




Epoch 18, Batch 50, Loss: 0.03678001090884209




Epoch 18, Batch 60, Loss: 0.039620183408260345




Epoch 18, Batch 70, Loss: 0.039948128163814545




Epoch 18, Batch 80, Loss: 0.02961641363799572




Epoch 18, Batch 90, Loss: 0.04580056294798851


Epoch:  19%|█▉        | 19/100 [19:00<48:32, 35.95s/it]

Epoch 19, Batch 0, Loss: 0.04259585589170456




Epoch 19, Batch 10, Loss: 0.03397706523537636




Epoch 19, Batch 20, Loss: 0.031248798593878746




Epoch 19, Batch 30, Loss: 0.04250355809926987




Epoch 19, Batch 40, Loss: 0.03378823772072792




Epoch 19, Batch 50, Loss: 0.031056473031640053




Epoch 19, Batch 60, Loss: 0.03728770092129707




Epoch 19, Batch 70, Loss: 0.026049556210637093




Epoch 19, Batch 80, Loss: 0.023618126288056374




Epoch 19, Batch 90, Loss: 0.03250256925821304


Epoch:  20%|██        | 20/100 [19:31<46:10, 34.63s/it]

Epoch 20, Batch 0, Loss: 0.029100993648171425




Epoch 20, Batch 10, Loss: 0.02774575911462307




Epoch 20, Batch 20, Loss: 0.02728586085140705




Epoch 20, Batch 30, Loss: 0.02404148317873478




Epoch 20, Batch 40, Loss: 0.03975151106715202




Epoch 20, Batch 50, Loss: 0.02871577814221382




Epoch 20, Batch 60, Loss: 0.026137929409742355




Epoch 20, Batch 70, Loss: 0.0328282006084919




Epoch 20, Batch 80, Loss: 0.026570498943328857




Epoch 20, Batch 90, Loss: 0.02629578299820423


Epoch:  21%|██        | 21/100 [20:04<44:36, 33.88s/it]

Epoch 21, Batch 0, Loss: 0.028832077980041504




Epoch 21, Batch 10, Loss: 0.04044428467750549




Epoch 21, Batch 20, Loss: 0.025728654116392136




Epoch 21, Batch 30, Loss: 0.034973036497831345




Epoch 21, Batch 40, Loss: 0.025023166090250015




Epoch 21, Batch 50, Loss: 0.027498289942741394




Epoch 21, Batch 60, Loss: 0.020908521488308907




Epoch 21, Batch 70, Loss: 0.037648387253284454




Epoch 21, Batch 80, Loss: 0.03419991955161095




Epoch 21, Batch 90, Loss: 0.025532977655529976


Epoch:  22%|██▏       | 22/100 [20:35<43:15, 33.27s/it]

Epoch 22, Batch 0, Loss: 0.01967743970453739




Epoch 22, Batch 10, Loss: 0.029823020100593567




Epoch 22, Batch 20, Loss: 0.03301616013050079




Epoch 22, Batch 30, Loss: 0.023016830906271935




Epoch 22, Batch 40, Loss: 0.025811661034822464




Epoch 22, Batch 50, Loss: 0.02854127436876297




Epoch 22, Batch 60, Loss: 0.027811691164970398




Epoch 22, Batch 70, Loss: 0.027453076094388962




Epoch 22, Batch 80, Loss: 0.03983534872531891




Epoch 22, Batch 90, Loss: 0.03873798996210098


Epoch:  23%|██▎       | 23/100 [21:07<42:09, 32.86s/it]

Epoch 23, Batch 0, Loss: 0.037083107978105545




Epoch 23, Batch 10, Loss: 0.03499959781765938




Epoch 23, Batch 20, Loss: 0.02570919692516327




Epoch 23, Batch 30, Loss: 0.030591724440455437




Epoch 23, Batch 40, Loss: 0.03949444368481636




Epoch 23, Batch 50, Loss: 0.038510214537382126




Epoch 23, Batch 60, Loss: 0.028662940487265587




Epoch 23, Batch 70, Loss: 0.02714165300130844




Epoch 23, Batch 80, Loss: 0.038997817784547806




Epoch 23, Batch 90, Loss: 0.029877200722694397


Epoch:  24%|██▍       | 24/100 [21:39<41:19, 32.62s/it]

Epoch 24, Batch 0, Loss: 0.02480747550725937




Epoch 24, Batch 10, Loss: 0.02778959646821022




Epoch 24, Batch 20, Loss: 0.02986406907439232




Epoch 24, Batch 30, Loss: 0.02438465505838394




Epoch 24, Batch 40, Loss: 0.027068987488746643




Epoch 24, Batch 50, Loss: 0.028301160782575607




Epoch 24, Batch 60, Loss: 0.03433297201991081




Epoch 24, Batch 70, Loss: 0.041134949773550034




Epoch 24, Batch 80, Loss: 0.03131459653377533




Epoch 24, Batch 90, Loss: 0.028186779469251633


Epoch:  25%|██▌       | 25/100 [22:11<40:28, 32.37s/it]

Epoch 25, Batch 0, Loss: 0.03427521511912346




Epoch 25, Batch 10, Loss: 0.025058606639504433




Epoch 25, Batch 20, Loss: 0.02764417603611946




Epoch 25, Batch 30, Loss: 0.029184071347117424




Epoch 25, Batch 40, Loss: 0.03449892997741699




Epoch 25, Batch 50, Loss: 0.031140873208642006




Epoch 25, Batch 60, Loss: 0.03275657072663307




Epoch 25, Batch 70, Loss: 0.03295736387372017




Epoch 25, Batch 80, Loss: 0.03664573282003403




Epoch 25, Batch 90, Loss: 0.029103100299835205


Epoch:  26%|██▌       | 26/100 [22:43<39:43, 32.21s/it]

Epoch 26, Batch 0, Loss: 0.028727838769555092




Epoch 26, Batch 10, Loss: 0.028000516816973686




Epoch 26, Batch 20, Loss: 0.027111267670989037




Epoch 26, Batch 30, Loss: 0.02279485948383808




Epoch 26, Batch 40, Loss: 0.03329811990261078




Epoch 26, Batch 50, Loss: 0.025353817269206047




Epoch 26, Batch 60, Loss: 0.03023718111217022




Epoch 26, Batch 70, Loss: 0.0285684484988451




Epoch 26, Batch 80, Loss: 0.036063309758901596




Epoch 26, Batch 90, Loss: 0.028797363862395287


Epoch:  27%|██▋       | 27/100 [23:15<39:01, 32.07s/it]

Epoch 27, Batch 0, Loss: 0.031905557960271835




Epoch 27, Batch 10, Loss: 0.03405848145484924




Epoch 27, Batch 20, Loss: 0.02313152514398098




Epoch 27, Batch 30, Loss: 0.029294218868017197




Epoch 27, Batch 40, Loss: 0.031937651336193085




Epoch 27, Batch 50, Loss: 0.026596546173095703




Epoch 27, Batch 60, Loss: 0.036433763802051544




Epoch 27, Batch 70, Loss: 0.030561471357941628




Epoch 27, Batch 80, Loss: 0.03125789016485214




Epoch 27, Batch 90, Loss: 0.04191891476511955


Epoch:  28%|██▊       | 28/100 [23:46<38:21, 31.96s/it]

Epoch 28, Batch 0, Loss: 0.0367472767829895




Epoch 28, Batch 10, Loss: 0.027039838954806328




Epoch 28, Batch 20, Loss: 0.028050322085618973




Epoch 28, Batch 30, Loss: 0.021997688338160515




Epoch 28, Batch 40, Loss: 0.024517759680747986




Epoch 28, Batch 50, Loss: 0.02346811443567276




Epoch 28, Batch 60, Loss: 0.027873700484633446




Epoch 28, Batch 70, Loss: 0.02344086393713951




Epoch 28, Batch 80, Loss: 0.026739539578557014




Epoch 28, Batch 90, Loss: 0.02500600554049015


Epoch:  29%|██▉       | 29/100 [24:19<38:03, 32.17s/it]

Epoch 29, Batch 0, Loss: 0.022737732157111168




Epoch 29, Batch 10, Loss: 0.026058293879032135




Epoch 29, Batch 20, Loss: 0.035137616097927094




Epoch 29, Batch 30, Loss: 0.03457065299153328




Epoch 29, Batch 40, Loss: 0.027251217514276505




Epoch 29, Batch 50, Loss: 0.03716740384697914




Epoch 29, Batch 60, Loss: 0.022683877497911453




Epoch 29, Batch 70, Loss: 0.031617723405361176




Epoch 29, Batch 80, Loss: 0.03376564756035805




Epoch 29, Batch 90, Loss: 0.02149696834385395


Epoch:  30%|███       | 30/100 [24:51<37:27, 32.11s/it]

Epoch 30, Batch 0, Loss: 0.024471884593367577




Epoch 30, Batch 10, Loss: 0.03248615562915802




Epoch 30, Batch 20, Loss: 0.030838284641504288




Epoch 30, Batch 30, Loss: 0.021874282509088516




Epoch 30, Batch 40, Loss: 0.02589605748653412




Epoch 30, Batch 50, Loss: 0.03258023038506508




Epoch 30, Batch 60, Loss: 0.027065986767411232




Epoch 30, Batch 70, Loss: 0.020743535831570625




Epoch 30, Batch 80, Loss: 0.03214326500892639




Epoch 30, Batch 90, Loss: 0.022563014179468155


Epoch:  31%|███       | 31/100 [25:24<37:04, 32.25s/it]

Epoch 31, Batch 0, Loss: 0.03131671994924545




Epoch 31, Batch 10, Loss: 0.024846332147717476




Epoch 31, Batch 20, Loss: 0.021233506500720978




Epoch 31, Batch 30, Loss: 0.02700771763920784




Epoch 31, Batch 40, Loss: 0.026156798005104065




Epoch 31, Batch 50, Loss: 0.028566190972924232




Epoch 31, Batch 60, Loss: 0.025780396535992622




Epoch 31, Batch 70, Loss: 0.03220392391085625




Epoch 31, Batch 80, Loss: 0.03283816576004028




Epoch 31, Batch 90, Loss: 0.03253334015607834


Epoch:  32%|███▏      | 32/100 [25:55<36:20, 32.06s/it]

Epoch 32, Batch 0, Loss: 0.02647382766008377




Epoch 32, Batch 10, Loss: 0.023961396887898445




Epoch 32, Batch 20, Loss: 0.02538124844431877




Epoch 32, Batch 30, Loss: 0.029889238998293877




Epoch 32, Batch 40, Loss: 0.03023425117135048




Epoch 32, Batch 50, Loss: 0.024716628715395927




Epoch 32, Batch 60, Loss: 0.026322221383452415




Epoch 32, Batch 70, Loss: 0.01834980957210064




Epoch 32, Batch 80, Loss: 0.024759763851761818




Epoch 32, Batch 90, Loss: 0.03294280916452408


Epoch:  33%|███▎      | 33/100 [26:27<35:46, 32.03s/it]

Epoch 33, Batch 0, Loss: 0.019997987896203995




Epoch 33, Batch 10, Loss: 0.01993803307414055




Epoch 33, Batch 20, Loss: 0.02242448925971985




Epoch 33, Batch 30, Loss: 0.02847692370414734




Epoch 33, Batch 40, Loss: 0.025888213887810707




Epoch 33, Batch 50, Loss: 0.03098474070429802




Epoch 33, Batch 60, Loss: 0.043028172105550766




Epoch 33, Batch 70, Loss: 0.03248177468776703




Epoch 33, Batch 80, Loss: 0.027478137984871864




Epoch 33, Batch 90, Loss: 0.036890909075737


Epoch:  34%|███▍      | 34/100 [27:00<35:37, 32.39s/it]

Epoch 34, Batch 0, Loss: 0.04063580185174942




Epoch 34, Batch 10, Loss: 0.03655150160193443




Epoch 34, Batch 20, Loss: 0.028021028265357018




Epoch 34, Batch 30, Loss: 0.032635100185871124




Epoch 34, Batch 40, Loss: 0.021681709215044975




Epoch 34, Batch 50, Loss: 0.022536007687449455




Epoch 34, Batch 60, Loss: 0.019998518750071526




Epoch 34, Batch 70, Loss: 0.026888685300946236




Epoch 34, Batch 80, Loss: 0.023066962137818336




Epoch 34, Batch 90, Loss: 0.02257240191102028


Epoch:  35%|███▌      | 35/100 [27:32<34:53, 32.20s/it]

Epoch 35, Batch 0, Loss: 0.01800733059644699




Epoch 35, Batch 10, Loss: 0.02711399830877781




Epoch 35, Batch 20, Loss: 0.035023052245378494




Epoch 35, Batch 30, Loss: 0.02796507254242897




Epoch 35, Batch 40, Loss: 0.025250360369682312




Epoch 35, Batch 50, Loss: 0.02282523922622204




Epoch 35, Batch 60, Loss: 0.03354493901133537




Epoch 35, Batch 70, Loss: 0.028041420504450798




Epoch 35, Batch 80, Loss: 0.026281895115971565




Epoch 35, Batch 90, Loss: 0.031381767243146896


Epoch:  36%|███▌      | 36/100 [28:04<34:10, 32.05s/it]

Epoch 36, Batch 0, Loss: 0.03640943020582199




Epoch 36, Batch 10, Loss: 0.0359857901930809




Epoch 36, Batch 20, Loss: 0.02146206982433796




Epoch 36, Batch 30, Loss: 0.02805524505674839




Epoch 36, Batch 40, Loss: 0.0268420297652483




Epoch 36, Batch 50, Loss: 0.023012181743979454




Epoch 36, Batch 60, Loss: 0.027680961415171623




Epoch 36, Batch 70, Loss: 0.03015732392668724




Epoch 36, Batch 80, Loss: 0.02470102533698082




Epoch 36, Batch 90, Loss: 0.022464711219072342


Epoch:  37%|███▋      | 37/100 [28:36<33:34, 31.97s/it]

Epoch 37, Batch 0, Loss: 0.024477537721395493




Epoch 37, Batch 10, Loss: 0.03324523940682411




Epoch 37, Batch 20, Loss: 0.025421274825930595




Epoch 37, Batch 30, Loss: 0.028023315593600273




Epoch 37, Batch 40, Loss: 0.027125321328639984




Epoch 37, Batch 50, Loss: 0.02445741556584835




Epoch 37, Batch 60, Loss: 0.022922633215785027




Epoch 37, Batch 70, Loss: 0.027931824326515198




Epoch 37, Batch 80, Loss: 0.026031214743852615




Epoch 37, Batch 90, Loss: 0.029551541432738304


Epoch:  38%|███▊      | 38/100 [29:08<33:03, 31.99s/it]

Epoch 38, Batch 0, Loss: 0.03677930682897568




Epoch 38, Batch 10, Loss: 0.025210173800587654




Epoch 38, Batch 20, Loss: 0.026648717001080513




Epoch 38, Batch 30, Loss: 0.02590513974428177




Epoch 38, Batch 40, Loss: 0.02524406835436821




Epoch 38, Batch 50, Loss: 0.020200956612825394




Epoch 38, Batch 60, Loss: 0.03086705133318901




Epoch 38, Batch 70, Loss: 0.023568978533148766




Epoch 38, Batch 80, Loss: 0.01992260105907917




Epoch 38, Batch 90, Loss: 0.026054492220282555


Epoch:  39%|███▉      | 39/100 [29:39<32:23, 31.87s/it]

Epoch 39, Batch 0, Loss: 0.01832401379942894




Epoch 39, Batch 10, Loss: 0.02417115494608879




Epoch 39, Batch 20, Loss: 0.020278124138712883




Epoch 39, Batch 30, Loss: 0.024983488023281097




Epoch 39, Batch 40, Loss: 0.017200468108057976




Epoch 39, Batch 50, Loss: 0.019344309344887733




Epoch 39, Batch 60, Loss: 0.0192120224237442




Epoch 39, Batch 70, Loss: 0.020146658644080162




Epoch 39, Batch 80, Loss: 0.030852951109409332




Epoch 39, Batch 90, Loss: 0.020641669631004333


Epoch:  40%|████      | 40/100 [30:11<31:47, 31.80s/it]

Epoch 40, Batch 0, Loss: 0.022166145965456963




Epoch 40, Batch 10, Loss: 0.027483444660902023




Epoch 40, Batch 20, Loss: 0.03977104276418686




Epoch 40, Batch 30, Loss: 0.017279617488384247




Epoch 40, Batch 40, Loss: 0.028168048709630966




Epoch 40, Batch 50, Loss: 0.024773728102445602




Epoch 40, Batch 60, Loss: 0.024632025510072708




Epoch 40, Batch 70, Loss: 0.021763132885098457




Epoch 40, Batch 80, Loss: 0.020033344626426697




Epoch 40, Batch 90, Loss: 0.022884763777256012


Epoch:  41%|████      | 41/100 [30:43<31:26, 31.98s/it]

Epoch 41, Batch 0, Loss: 0.023324504494667053




Epoch 41, Batch 10, Loss: 0.022558903321623802




Epoch 41, Batch 20, Loss: 0.0238756462931633




Epoch 41, Batch 30, Loss: 0.02008354663848877




Epoch 41, Batch 40, Loss: 0.0213757511228323




Epoch 41, Batch 50, Loss: 0.01990528777241707




Epoch 41, Batch 60, Loss: 0.020652908831834793




Epoch 41, Batch 70, Loss: 0.022761015221476555




Epoch 41, Batch 80, Loss: 0.028484389185905457




Epoch 41, Batch 90, Loss: 0.021370533853769302


Epoch:  42%|████▏     | 42/100 [31:16<30:57, 32.03s/it]

Epoch 42, Batch 0, Loss: 0.021847642958164215




Epoch 42, Batch 10, Loss: 0.033535804599523544




Epoch 42, Batch 20, Loss: 0.03958481550216675




Epoch 42, Batch 30, Loss: 0.024444198235869408




Epoch 42, Batch 40, Loss: 0.027661295607686043




Epoch 42, Batch 50, Loss: 0.028084363788366318




Epoch 42, Batch 60, Loss: 0.022388577461242676




Epoch 42, Batch 70, Loss: 0.018202506005764008




Epoch 42, Batch 80, Loss: 0.026252716779708862




Epoch 42, Batch 90, Loss: 0.01976301707327366


Epoch:  43%|████▎     | 43/100 [31:47<30:20, 31.93s/it]

Epoch 43, Batch 0, Loss: 0.038033001124858856




Epoch 43, Batch 10, Loss: 0.028685439378023148




Epoch 43, Batch 20, Loss: 0.019590318202972412




Epoch 43, Batch 30, Loss: 0.02113347128033638




Epoch 43, Batch 40, Loss: 0.01802600547671318




Epoch 43, Batch 50, Loss: 0.024220887571573257




Epoch 43, Batch 60, Loss: 0.029537934809923172




Epoch 43, Batch 70, Loss: 0.021669255569577217




Epoch 43, Batch 80, Loss: 0.03303299844264984




Epoch 43, Batch 90, Loss: 0.027660207822918892


Epoch:  44%|████▍     | 44/100 [32:20<30:00, 32.14s/it]

Epoch 44, Batch 0, Loss: 0.029248785227537155




Epoch 44, Batch 10, Loss: 0.033349789679050446




Epoch 44, Batch 20, Loss: 0.01895267516374588




Epoch 44, Batch 30, Loss: 0.02824518084526062




Epoch 44, Batch 40, Loss: 0.02365107648074627




Epoch 44, Batch 50, Loss: 0.025712639093399048




Epoch 44, Batch 60, Loss: 0.02557014860212803




Epoch 44, Batch 70, Loss: 0.026530688628554344




Epoch 44, Batch 80, Loss: 0.029383234679698944




Epoch 44, Batch 90, Loss: 0.027155211195349693


Epoch:  45%|████▌     | 45/100 [32:53<29:38, 32.34s/it]

Epoch 45, Batch 0, Loss: 0.02241477370262146




Epoch 45, Batch 10, Loss: 0.022546149790287018




Epoch 45, Batch 20, Loss: 0.01563180610537529




Epoch 45, Batch 30, Loss: 0.03130606934428215




Epoch 45, Batch 40, Loss: 0.02462904527783394




Epoch 45, Batch 50, Loss: 0.026071229949593544




Epoch 45, Batch 60, Loss: 0.02098182588815689




Epoch 45, Batch 70, Loss: 0.01949957199394703




Epoch 45, Batch 80, Loss: 0.018587103113532066




Epoch 45, Batch 90, Loss: 0.02350643463432789


Epoch:  46%|████▌     | 46/100 [33:24<28:53, 32.10s/it]

Epoch 46, Batch 0, Loss: 0.021525021642446518




Epoch 46, Batch 10, Loss: 0.026002680882811546




Epoch 46, Batch 20, Loss: 0.022758295759558678




Epoch 46, Batch 30, Loss: 0.022074637934565544




Epoch 46, Batch 40, Loss: 0.018888600170612335




Epoch 46, Batch 50, Loss: 0.020092129707336426




Epoch 46, Batch 60, Loss: 0.02491483837366104




Epoch 46, Batch 70, Loss: 0.029244549572467804




Epoch 46, Batch 80, Loss: 0.02344808541238308




Epoch 46, Batch 90, Loss: 0.023427104577422142


Epoch:  47%|████▋     | 47/100 [33:56<28:12, 31.94s/it]

Epoch 47, Batch 0, Loss: 0.018167883157730103




Epoch 47, Batch 10, Loss: 0.02322203665971756




Epoch 47, Batch 20, Loss: 0.018244557082653046




Epoch 47, Batch 30, Loss: 0.014666343107819557




Epoch 47, Batch 40, Loss: 0.019624769687652588




Epoch 47, Batch 50, Loss: 0.022095421329140663




Epoch 47, Batch 60, Loss: 0.025483334437012672




Epoch 47, Batch 70, Loss: 0.021784598007798195




Epoch 47, Batch 80, Loss: 0.021299511194229126




Epoch 47, Batch 90, Loss: 0.01664435863494873


Epoch:  48%|████▊     | 48/100 [34:27<27:34, 31.82s/it]

Epoch 48, Batch 0, Loss: 0.028778310865163803




Epoch 48, Batch 10, Loss: 0.022588588297367096




Epoch 48, Batch 20, Loss: 0.027500730007886887




Epoch 48, Batch 30, Loss: 0.03091050311923027




Epoch 48, Batch 40, Loss: 0.029134973883628845




Epoch 48, Batch 50, Loss: 0.03739570453763008




Epoch 48, Batch 60, Loss: 0.03052104078233242




Epoch 48, Batch 70, Loss: 0.026409756392240524




Epoch 48, Batch 80, Loss: 0.024474048987030983




Epoch 48, Batch 90, Loss: 0.026990894228219986


Epoch:  49%|████▉     | 49/100 [34:59<26:59, 31.75s/it]

Epoch 49, Batch 0, Loss: 0.01976710371673107




Epoch 49, Batch 10, Loss: 0.026107434183359146




Epoch 49, Batch 20, Loss: 0.02228998951613903




Epoch 49, Batch 30, Loss: 0.021487142890691757




Epoch 49, Batch 40, Loss: 0.01668013259768486




Epoch 49, Batch 50, Loss: 0.020579230040311813




Epoch 49, Batch 60, Loss: 0.023626791313290596




Epoch 49, Batch 70, Loss: 0.018685994669795036




Epoch 49, Batch 80, Loss: 0.01893521472811699




Epoch 49, Batch 90, Loss: 0.01853223703801632


Epoch:  50%|█████     | 50/100 [35:31<26:31, 31.83s/it]

Epoch 50, Batch 0, Loss: 0.021056536585092545




Epoch 50, Batch 10, Loss: 0.018268801271915436




Epoch 50, Batch 20, Loss: 0.019802050665020943




Epoch 50, Batch 30, Loss: 0.021605124697089195




Epoch 50, Batch 40, Loss: 0.02240731194615364




Epoch 50, Batch 50, Loss: 0.015710610896348953




Epoch 50, Batch 60, Loss: 0.024659916758537292




Epoch 50, Batch 70, Loss: 0.02379925362765789




Epoch 50, Batch 80, Loss: 0.028617922216653824




Epoch 50, Batch 90, Loss: 0.019896110519766808


Epoch:  51%|█████     | 51/100 [36:03<26:07, 31.99s/it]

Epoch 51, Batch 0, Loss: 0.0225357748568058




Epoch 51, Batch 10, Loss: 0.029353519901633263




Epoch 51, Batch 20, Loss: 0.022754278033971786




Epoch 51, Batch 30, Loss: 0.025709059089422226




Epoch 51, Batch 40, Loss: 0.016530049964785576




Epoch 51, Batch 50, Loss: 0.028302952647209167




Epoch 51, Batch 60, Loss: 0.022249910980463028




Epoch 51, Batch 70, Loss: 0.013305925764143467




Epoch 51, Batch 80, Loss: 0.022901417687535286




Epoch 51, Batch 90, Loss: 0.03127800300717354


Epoch:  52%|█████▏    | 52/100 [36:35<25:30, 31.88s/it]

Epoch 52, Batch 0, Loss: 0.024841489270329475




Epoch 52, Batch 10, Loss: 0.017937645316123962




Epoch 52, Batch 20, Loss: 0.013535819016397




Epoch 52, Batch 30, Loss: 0.017724119126796722




Epoch 52, Batch 40, Loss: 0.020305989310145378




Epoch 52, Batch 50, Loss: 0.020958412438631058




Epoch 52, Batch 60, Loss: 0.02118454873561859




Epoch 52, Batch 70, Loss: 0.03249035030603409




Epoch 52, Batch 80, Loss: 0.029297592118382454




Epoch 52, Batch 90, Loss: 0.028009070083498955


Epoch:  53%|█████▎    | 53/100 [37:06<24:54, 31.79s/it]

Epoch 53, Batch 0, Loss: 0.025159098207950592




Epoch 53, Batch 10, Loss: 0.027154186740517616




Epoch 53, Batch 20, Loss: 0.021695706993341446




Epoch 53, Batch 30, Loss: 0.02606530487537384




Epoch 53, Batch 40, Loss: 0.016487648710608482




Epoch 53, Batch 50, Loss: 0.014670302160084248




Epoch 53, Batch 60, Loss: 0.015393472276628017




Epoch 53, Batch 70, Loss: 0.014578363858163357




Epoch 53, Batch 80, Loss: 0.02739613875746727




Epoch 53, Batch 90, Loss: 0.019784031435847282


Epoch:  54%|█████▍    | 54/100 [37:38<24:19, 31.74s/it]

Epoch 54, Batch 0, Loss: 0.022858232259750366




Epoch 54, Batch 10, Loss: 0.01675352267920971




Epoch 54, Batch 20, Loss: 0.019953353330492973




Epoch 54, Batch 30, Loss: 0.016885992139577866




Epoch 54, Batch 40, Loss: 0.020214462652802467




Epoch 54, Batch 50, Loss: 0.02273779921233654




Epoch 54, Batch 60, Loss: 0.021713677793741226




Epoch 54, Batch 70, Loss: 0.020005064085125923




Epoch 54, Batch 80, Loss: 0.0196173544973135




Epoch 54, Batch 90, Loss: 0.027519984170794487


Epoch:  55%|█████▌    | 55/100 [38:10<23:45, 31.68s/it]

Epoch 55, Batch 0, Loss: 0.02745344676077366




Epoch 55, Batch 10, Loss: 0.023231489583849907




Epoch 55, Batch 20, Loss: 0.02851390466094017




Epoch 55, Batch 30, Loss: 0.016434723511338234




Epoch 55, Batch 40, Loss: 0.02103094756603241




Epoch 55, Batch 50, Loss: 0.017116624861955643




Epoch 55, Batch 60, Loss: 0.01735779456794262




Epoch 55, Batch 70, Loss: 0.02222573570907116




Epoch 55, Batch 80, Loss: 0.017300913110375404




Epoch 55, Batch 90, Loss: 0.016390429809689522


Epoch:  56%|█████▌    | 56/100 [38:41<23:12, 31.65s/it]

Epoch 56, Batch 0, Loss: 0.01885402947664261




Epoch 56, Batch 10, Loss: 0.01866471953690052




Epoch 56, Batch 20, Loss: 0.022152790799736977




Epoch 56, Batch 30, Loss: 0.021157599985599518




Epoch 56, Batch 40, Loss: 0.01646558754146099




Epoch 56, Batch 50, Loss: 0.020998338237404823




Epoch 56, Batch 60, Loss: 0.022638050839304924




Epoch 56, Batch 70, Loss: 0.019917169585824013




Epoch 56, Batch 80, Loss: 0.029534339904785156




Epoch 56, Batch 90, Loss: 0.02550332248210907


Epoch:  57%|█████▋    | 57/100 [39:13<22:39, 31.61s/it]

Epoch 57, Batch 0, Loss: 0.02572874166071415




Epoch 57, Batch 10, Loss: 0.021358672529459




Epoch 57, Batch 20, Loss: 0.02291577309370041




Epoch 57, Batch 30, Loss: 0.020988985896110535




Epoch 57, Batch 40, Loss: 0.02262684516608715




Epoch 57, Batch 50, Loss: 0.02469712868332863




Epoch 57, Batch 60, Loss: 0.023035505786538124




Epoch 57, Batch 70, Loss: 0.018668081611394882




Epoch 57, Batch 80, Loss: 0.021707821637392044




Epoch 57, Batch 90, Loss: 0.022633671760559082


Epoch:  58%|█████▊    | 58/100 [39:44<22:07, 31.60s/it]

Epoch 58, Batch 0, Loss: 0.021412717178463936




Epoch 58, Batch 10, Loss: 0.01604374684393406




Epoch 58, Batch 20, Loss: 0.017276853322982788




Epoch 58, Batch 30, Loss: 0.023860570043325424




Epoch 58, Batch 40, Loss: 0.017272919416427612




Epoch 58, Batch 50, Loss: 0.02111137844622135




Epoch 58, Batch 60, Loss: 0.0199496541172266




Epoch 58, Batch 70, Loss: 0.028036532923579216




Epoch 58, Batch 80, Loss: 0.02653992548584938




Epoch 58, Batch 90, Loss: 0.02967575564980507


Epoch:  59%|█████▉    | 59/100 [40:16<21:35, 31.59s/it]

Epoch 59, Batch 0, Loss: 0.029074041172862053




Epoch 59, Batch 10, Loss: 0.019694337621331215




Epoch 59, Batch 20, Loss: 0.023725096136331558




Epoch 59, Batch 30, Loss: 0.032805703580379486




Epoch 59, Batch 40, Loss: 0.025719985365867615




Epoch 59, Batch 50, Loss: 0.0215693861246109




Epoch 59, Batch 60, Loss: 0.023275816813111305




Epoch 59, Batch 70, Loss: 0.01571378856897354




Epoch 59, Batch 80, Loss: 0.024044422432780266




Epoch 59, Batch 90, Loss: 0.019068032503128052


Epoch:  60%|██████    | 60/100 [40:48<21:04, 31.61s/it]

Epoch 60, Batch 0, Loss: 0.018417851999402046




Epoch 60, Batch 10, Loss: 0.015357470139861107




Epoch 60, Batch 20, Loss: 0.018230542540550232




Epoch 60, Batch 30, Loss: 0.017242776229977608




Epoch 60, Batch 40, Loss: 0.019394218921661377




Epoch 60, Batch 50, Loss: 0.02049923874437809




Epoch 60, Batch 60, Loss: 0.020856164395809174




Epoch 60, Batch 70, Loss: 0.02225934900343418




Epoch 60, Batch 80, Loss: 0.024252068251371384




Epoch 60, Batch 90, Loss: 0.021394815295934677


Epoch:  61%|██████    | 61/100 [41:20<20:44, 31.90s/it]

Epoch 61, Batch 0, Loss: 0.026319893077015877




Epoch 61, Batch 10, Loss: 0.024920552968978882




Epoch 61, Batch 20, Loss: 0.022020716220140457




Epoch 61, Batch 30, Loss: 0.018604321405291557




Epoch 61, Batch 40, Loss: 0.028659716248512268




Epoch 61, Batch 50, Loss: 0.025468425825238228




Epoch 61, Batch 60, Loss: 0.022332316264510155




Epoch 61, Batch 70, Loss: 0.02821270562708378




Epoch 61, Batch 80, Loss: 0.01958215981721878




Epoch 61, Batch 90, Loss: 0.023303795605897903


Epoch:  62%|██████▏   | 62/100 [41:52<20:08, 31.80s/it]

Epoch 62, Batch 0, Loss: 0.017724458128213882




Epoch 62, Batch 10, Loss: 0.016322098672389984




Epoch 62, Batch 20, Loss: 0.029639575630426407




Epoch 62, Batch 30, Loss: 0.02017042599618435




Epoch 62, Batch 40, Loss: 0.017133910208940506




Epoch 62, Batch 50, Loss: 0.018471838906407356




Epoch 62, Batch 60, Loss: 0.018962090834975243




Epoch 62, Batch 70, Loss: 0.02182014100253582




Epoch 62, Batch 80, Loss: 0.022140594199299812




Epoch 62, Batch 90, Loss: 0.024463282898068428


Epoch:  63%|██████▎   | 63/100 [42:23<19:35, 31.76s/it]

Epoch 63, Batch 0, Loss: 0.02015802264213562




Epoch 63, Batch 10, Loss: 0.02465180866420269




Epoch 63, Batch 20, Loss: 0.023178428411483765




Epoch 63, Batch 30, Loss: 0.025969702750444412




Epoch 63, Batch 40, Loss: 0.026584960520267487




Epoch 63, Batch 50, Loss: 0.021509602665901184




Epoch 63, Batch 60, Loss: 0.018926352262496948




Epoch 63, Batch 70, Loss: 0.028389822691679




Epoch 63, Batch 80, Loss: 0.021333787590265274




Epoch 63, Batch 90, Loss: 0.03204982727766037


Epoch:  64%|██████▍   | 64/100 [42:55<19:01, 31.71s/it]

Epoch 64, Batch 0, Loss: 0.012262246571481228




Epoch 64, Batch 10, Loss: 0.01960628293454647




Epoch 64, Batch 20, Loss: 0.01624423637986183




Epoch 64, Batch 30, Loss: 0.019762730225920677




Epoch 64, Batch 40, Loss: 0.02009209804236889




Epoch 64, Batch 50, Loss: 0.0171248447149992




Epoch 64, Batch 60, Loss: 0.024406328797340393




Epoch 64, Batch 70, Loss: 0.025829080492258072




Epoch 64, Batch 80, Loss: 0.026752231642603874




Epoch 64, Batch 90, Loss: 0.022799061611294746


Epoch:  65%|██████▌   | 65/100 [43:27<18:29, 31.71s/it]

Epoch 65, Batch 0, Loss: 0.019260281696915627




Epoch 65, Batch 10, Loss: 0.02023685723543167




Epoch 65, Batch 20, Loss: 0.023122908547520638




Epoch 65, Batch 30, Loss: 0.03439376503229141




Epoch 65, Batch 40, Loss: 0.02591182477772236




Epoch 65, Batch 50, Loss: 0.02385399304330349




Epoch 65, Batch 60, Loss: 0.019335946068167686




Epoch 65, Batch 70, Loss: 0.017865663394331932




Epoch 65, Batch 80, Loss: 0.024726413190364838




Epoch 65, Batch 90, Loss: 0.026337850838899612


Epoch:  66%|██████▌   | 66/100 [43:59<18:00, 31.77s/it]

Epoch 66, Batch 0, Loss: 0.022421790286898613




Epoch 66, Batch 10, Loss: 0.017304450273513794




Epoch 66, Batch 20, Loss: 0.01714247465133667




Epoch 66, Batch 30, Loss: 0.01987343467772007




Epoch 66, Batch 40, Loss: 0.016817500814795494




Epoch 66, Batch 50, Loss: 0.016158875077962875




Epoch 66, Batch 60, Loss: 0.021046698093414307




Epoch 66, Batch 70, Loss: 0.02533118985593319




Epoch 66, Batch 80, Loss: 0.02239816263318062




Epoch 66, Batch 90, Loss: 0.018841661512851715


Epoch:  67%|██████▋   | 67/100 [44:31<17:31, 31.85s/it]

Epoch 67, Batch 0, Loss: 0.022073350846767426




Epoch 67, Batch 10, Loss: 0.01995684951543808




Epoch 67, Batch 20, Loss: 0.024471594020724297




Epoch 67, Batch 30, Loss: 0.032783377915620804




Epoch 67, Batch 40, Loss: 0.019954565912485123




Epoch 67, Batch 50, Loss: 0.023910505697131157




Epoch 67, Batch 60, Loss: 0.0205612163990736




Epoch 67, Batch 70, Loss: 0.016041530296206474




Epoch 67, Batch 80, Loss: 0.024666668847203255




Epoch 67, Batch 90, Loss: 0.02071608230471611


Epoch:  68%|██████▊   | 68/100 [45:03<17:01, 31.91s/it]

Epoch 68, Batch 0, Loss: 0.024413924664258957




Epoch 68, Batch 10, Loss: 0.01965872384607792




Epoch 68, Batch 20, Loss: 0.018760303035378456




Epoch 68, Batch 30, Loss: 0.019840266555547714




Epoch 68, Batch 40, Loss: 0.023874277248978615




Epoch 68, Batch 50, Loss: 0.01594240963459015




Epoch 68, Batch 60, Loss: 0.022791368886828423




Epoch 68, Batch 70, Loss: 0.013466528616845608




Epoch 68, Batch 80, Loss: 0.02225419320166111




Epoch 68, Batch 90, Loss: 0.01815839298069477


Epoch:  69%|██████▉   | 69/100 [45:34<16:27, 31.84s/it]

Epoch 69, Batch 0, Loss: 0.02413945458829403




Epoch 69, Batch 10, Loss: 0.024744678288698196




Epoch 69, Batch 20, Loss: 0.029313325881958008




Epoch 69, Batch 30, Loss: 0.021583324298262596




Epoch 69, Batch 40, Loss: 0.02047988772392273




Epoch 69, Batch 50, Loss: 0.022782323881983757




Epoch 69, Batch 60, Loss: 0.02350659854710102




Epoch 69, Batch 70, Loss: 0.026236075907945633




Epoch 69, Batch 80, Loss: 0.021562673151493073




Epoch 69, Batch 90, Loss: 0.026843484491109848


Epoch:  70%|███████   | 70/100 [46:06<15:53, 31.79s/it]

Epoch 70, Batch 0, Loss: 0.019123878329992294




Epoch 70, Batch 10, Loss: 0.020817166194319725




Epoch 70, Batch 20, Loss: 0.029967134818434715




Epoch 70, Batch 30, Loss: 0.019027426838874817




Epoch 70, Batch 40, Loss: 0.01771349459886551




Epoch 70, Batch 50, Loss: 0.010586995631456375




Epoch 70, Batch 60, Loss: 0.015160455368459225




Epoch 70, Batch 70, Loss: 0.02298150211572647




Epoch 70, Batch 80, Loss: 0.01817169599235058




Epoch 70, Batch 90, Loss: 0.017988013103604317


Epoch:  71%|███████   | 71/100 [46:39<15:34, 32.22s/it]

Epoch 71, Batch 0, Loss: 0.01673363335430622




Epoch 71, Batch 10, Loss: 0.02031085081398487




Epoch 71, Batch 20, Loss: 0.016530927270650864




Epoch 71, Batch 30, Loss: 0.016045859083533287




Epoch 71, Batch 40, Loss: 0.023034047335386276




Epoch 71, Batch 50, Loss: 0.022006088867783546




Epoch 71, Batch 60, Loss: 0.02107849344611168




Epoch 71, Batch 70, Loss: 0.035084910690784454




Epoch 71, Batch 80, Loss: 0.02511434629559517




Epoch 71, Batch 90, Loss: 0.019896874204277992


Epoch:  72%|███████▏  | 72/100 [47:12<15:10, 32.51s/it]

Epoch 72, Batch 0, Loss: 0.01577662117779255




Epoch 72, Batch 10, Loss: 0.02477842941880226




Epoch 72, Batch 20, Loss: 0.019854510203003883




Epoch 72, Batch 30, Loss: 0.021434156224131584




Epoch 72, Batch 40, Loss: 0.016099445521831512




Epoch 72, Batch 50, Loss: 0.025068392977118492




Epoch 72, Batch 60, Loss: 0.01805376634001732




Epoch 72, Batch 70, Loss: 0.015077042393386364




Epoch 72, Batch 80, Loss: 0.019607767462730408




Epoch 72, Batch 90, Loss: 0.021442735567688942


Epoch:  73%|███████▎  | 73/100 [47:44<14:31, 32.29s/it]

Epoch 73, Batch 0, Loss: 0.021660974249243736




Epoch 73, Batch 10, Loss: 0.01996924914419651




Epoch 73, Batch 20, Loss: 0.02142460271716118




Epoch 73, Batch 30, Loss: 0.024502994492650032




Epoch 73, Batch 40, Loss: 0.03148839622735977




Epoch 73, Batch 50, Loss: 0.015144022181630135




Epoch 73, Batch 60, Loss: 0.012427919544279575




Epoch 73, Batch 70, Loss: 0.01955055072903633




Epoch 73, Batch 80, Loss: 0.02969977632164955




Epoch 73, Batch 90, Loss: 0.020327823236584663


Epoch:  74%|███████▍  | 74/100 [48:17<14:03, 32.45s/it]

Epoch 74, Batch 0, Loss: 0.02528340183198452




Epoch 74, Batch 10, Loss: 0.026799596846103668




Epoch 74, Batch 20, Loss: 0.020178688690066338




Epoch 74, Batch 30, Loss: 0.01784638874232769




Epoch 74, Batch 40, Loss: 0.018029192462563515




Epoch 74, Batch 50, Loss: 0.021251173689961433




Epoch 74, Batch 60, Loss: 0.0215061716735363




Epoch 74, Batch 70, Loss: 0.01930527575314045




Epoch 74, Batch 80, Loss: 0.023982161656022072




Epoch 74, Batch 90, Loss: 0.021114055067300797


Epoch:  75%|███████▌  | 75/100 [48:50<13:33, 32.52s/it]

Epoch 75, Batch 0, Loss: 0.014413241297006607




Epoch 75, Batch 10, Loss: 0.01734553463757038




Epoch 75, Batch 20, Loss: 0.013791495934128761




Epoch 75, Batch 30, Loss: 0.01984216459095478




Epoch 75, Batch 40, Loss: 0.017377162352204323




Epoch 75, Batch 50, Loss: 0.02018590085208416




Epoch 75, Batch 60, Loss: 0.019289959222078323




Epoch 75, Batch 70, Loss: 0.027896994724869728




Epoch 75, Batch 80, Loss: 0.023768717423081398




Epoch 75, Batch 90, Loss: 0.02329573966562748


Epoch:  76%|███████▌  | 76/100 [49:22<12:56, 32.34s/it]

Epoch 76, Batch 0, Loss: 0.019948942586779594




Epoch 76, Batch 10, Loss: 0.025135263800621033




Epoch 76, Batch 20, Loss: 0.023274997249245644




Epoch 76, Batch 30, Loss: 0.026541167870163918




Epoch 76, Batch 40, Loss: 0.017561541870236397




Epoch 76, Batch 50, Loss: 0.021266469731926918




Epoch 76, Batch 60, Loss: 0.016329394653439522




Epoch 76, Batch 70, Loss: 0.019244510680437088




Epoch 76, Batch 80, Loss: 0.019987870007753372




Epoch 76, Batch 90, Loss: 0.016456563025712967


Epoch:  77%|███████▋  | 77/100 [49:53<12:19, 32.15s/it]

Epoch 77, Batch 0, Loss: 0.019643442705273628




Epoch 77, Batch 10, Loss: 0.022358834743499756




Epoch 77, Batch 20, Loss: 0.021445469930768013




Epoch 77, Batch 30, Loss: 0.01799197308719158




Epoch 77, Batch 40, Loss: 0.01369906309992075




Epoch 77, Batch 50, Loss: 0.01712360419332981




Epoch 77, Batch 60, Loss: 0.01746976748108864




Epoch 77, Batch 70, Loss: 0.01936405897140503




Epoch 77, Batch 80, Loss: 0.01509336568415165




Epoch 77, Batch 90, Loss: 0.026084287092089653


Epoch:  78%|███████▊  | 78/100 [50:25<11:44, 32.04s/it]

Epoch 78, Batch 0, Loss: 0.0155340526252985




Epoch 78, Batch 10, Loss: 0.02082216925919056




Epoch 78, Batch 20, Loss: 0.019950944930315018




Epoch 78, Batch 30, Loss: 0.01938280463218689




Epoch 78, Batch 40, Loss: 0.02156923897564411




Epoch 78, Batch 50, Loss: 0.02411782741546631




Epoch 78, Batch 60, Loss: 0.019037608057260513




Epoch 78, Batch 70, Loss: 0.017913779243826866




Epoch 78, Batch 80, Loss: 0.022974520921707153




Epoch 78, Batch 90, Loss: 0.017526894807815552


Epoch:  79%|███████▉  | 79/100 [50:57<11:10, 31.91s/it]

Epoch 79, Batch 0, Loss: 0.0117054907605052




Epoch 79, Batch 10, Loss: 0.021437399089336395




Epoch 79, Batch 20, Loss: 0.014137540012598038




Epoch 79, Batch 30, Loss: 0.01588021032512188




Epoch 79, Batch 40, Loss: 0.01842491328716278




Epoch 79, Batch 50, Loss: 0.027862561866641045




Epoch 79, Batch 60, Loss: 0.02235359139740467




Epoch 79, Batch 70, Loss: 0.014454137533903122




Epoch 79, Batch 80, Loss: 0.02028241753578186




Epoch 79, Batch 90, Loss: 0.028188176453113556


Epoch:  80%|████████  | 80/100 [51:28<10:36, 31.81s/it]

Epoch 80, Batch 0, Loss: 0.022562775760889053




Epoch 80, Batch 10, Loss: 0.02630324475467205




Epoch 80, Batch 20, Loss: 0.02033062092959881




Epoch 80, Batch 30, Loss: 0.014551055617630482




Epoch 80, Batch 40, Loss: 0.024214863777160645




Epoch 80, Batch 50, Loss: 0.015820397064089775




Epoch 80, Batch 60, Loss: 0.01670762710273266




Epoch 80, Batch 70, Loss: 0.01606062240898609




Epoch 80, Batch 80, Loss: 0.025720341131091118




Epoch 80, Batch 90, Loss: 0.021842800080776215


Epoch:  81%|████████  | 81/100 [52:00<10:06, 31.92s/it]

Epoch 81, Batch 0, Loss: 0.02459568716585636




Epoch 81, Batch 10, Loss: 0.01167040690779686




Epoch 81, Batch 20, Loss: 0.024573680013418198




Epoch 81, Batch 30, Loss: 0.020331675186753273




Epoch 81, Batch 40, Loss: 0.02766868658363819




Epoch 81, Batch 50, Loss: 0.01446332037448883




Epoch 81, Batch 60, Loss: 0.017102014273405075




Epoch 81, Batch 70, Loss: 0.0190738458186388




Epoch 81, Batch 80, Loss: 0.01951473578810692




Epoch 81, Batch 90, Loss: 0.013017753139138222


Epoch:  82%|████████▏ | 82/100 [52:32<09:32, 31.82s/it]

Epoch 82, Batch 0, Loss: 0.022246187552809715




Epoch 82, Batch 10, Loss: 0.01848538964986801




Epoch 82, Batch 20, Loss: 0.018937962129712105




Epoch 82, Batch 30, Loss: 0.02098378911614418




Epoch 82, Batch 40, Loss: 0.018511846661567688




Epoch 82, Batch 50, Loss: 0.01772519387304783




Epoch 82, Batch 60, Loss: 0.02053150348365307




Epoch 82, Batch 70, Loss: 0.02278284728527069




Epoch 82, Batch 80, Loss: 0.014707159250974655




Epoch 82, Batch 90, Loss: 0.02061440609395504


Epoch:  83%|████████▎ | 83/100 [53:04<08:59, 31.75s/it]

Epoch 83, Batch 0, Loss: 0.018567431718111038




Epoch 83, Batch 10, Loss: 0.015730680897831917




Epoch 83, Batch 20, Loss: 0.02365644834935665




Epoch 83, Batch 30, Loss: 0.017476100474596024




Epoch 83, Batch 40, Loss: 0.02237625978887081




Epoch 83, Batch 50, Loss: 0.020869946107268333




Epoch 83, Batch 60, Loss: 0.013081779703497887




Epoch 83, Batch 70, Loss: 0.030001292005181313




Epoch 83, Batch 80, Loss: 0.017612874507904053




Epoch 83, Batch 90, Loss: 0.02430860698223114


Epoch:  84%|████████▍ | 84/100 [53:35<08:27, 31.69s/it]

Epoch 84, Batch 0, Loss: 0.017208196222782135




Epoch 84, Batch 10, Loss: 0.02038474753499031




Epoch 84, Batch 20, Loss: 0.018344100564718246




Epoch 84, Batch 30, Loss: 0.014917600899934769




Epoch 84, Batch 40, Loss: 0.018801195546984673




Epoch 84, Batch 50, Loss: 0.021790247410535812




Epoch 84, Batch 60, Loss: 0.02133169025182724




Epoch 84, Batch 70, Loss: 0.017902571707963943




Epoch 84, Batch 80, Loss: 0.01932438835501671




Epoch 84, Batch 90, Loss: 0.017211081460118294


Epoch:  85%|████████▌ | 85/100 [54:07<07:54, 31.66s/it]

Epoch 85, Batch 0, Loss: 0.0164700448513031




Epoch 85, Batch 10, Loss: 0.0141404764726758




Epoch 85, Batch 20, Loss: 0.01362455915659666




Epoch 85, Batch 30, Loss: 0.017005279660224915




Epoch 85, Batch 40, Loss: 0.015864050015807152




Epoch 85, Batch 50, Loss: 0.022347893565893173




Epoch 85, Batch 60, Loss: 0.014221088029444218




Epoch 85, Batch 70, Loss: 0.015412341803312302




Epoch 85, Batch 80, Loss: 0.01613757014274597




Epoch 85, Batch 90, Loss: 0.02242998778820038


Epoch:  86%|████████▌ | 86/100 [54:38<07:23, 31.65s/it]

Epoch 86, Batch 0, Loss: 0.012977611273527145




Epoch 86, Batch 10, Loss: 0.02510756254196167




Epoch 86, Batch 20, Loss: 0.030275655910372734




Epoch 86, Batch 30, Loss: 0.021650070324540138




Epoch 86, Batch 40, Loss: 0.01945088617503643




Epoch 86, Batch 50, Loss: 0.025157460942864418




Epoch 86, Batch 60, Loss: 0.02725895121693611




Epoch 86, Batch 70, Loss: 0.016865504905581474




Epoch 86, Batch 80, Loss: 0.019122261554002762




Epoch 86, Batch 90, Loss: 0.014788636937737465


Epoch:  87%|████████▋ | 87/100 [55:10<06:51, 31.64s/it]

Epoch 87, Batch 0, Loss: 0.01341139804571867




Epoch 87, Batch 10, Loss: 0.024064192548394203




Epoch 87, Batch 20, Loss: 0.017227498814463615




Epoch 87, Batch 30, Loss: 0.015863370150327682




Epoch 87, Batch 40, Loss: 0.018468698486685753




Epoch 87, Batch 50, Loss: 0.023888269439339638




Epoch 87, Batch 60, Loss: 0.01579274982213974




Epoch 87, Batch 70, Loss: 0.017289170995354652




Epoch 87, Batch 80, Loss: 0.018233729526400566




Epoch 87, Batch 90, Loss: 0.016453543677926064


Epoch:  88%|████████▊ | 88/100 [55:42<06:19, 31.62s/it]

Epoch 88, Batch 0, Loss: 0.031619757413864136




Epoch 88, Batch 10, Loss: 0.022916391491889954




Epoch 88, Batch 20, Loss: 0.029884101822972298




Epoch 88, Batch 30, Loss: 0.03371800482273102




Epoch 88, Batch 40, Loss: 0.01945585198700428




Epoch 88, Batch 50, Loss: 0.023451225832104683




Epoch 88, Batch 60, Loss: 0.0313992016017437




Epoch 88, Batch 70, Loss: 0.019285978749394417




Epoch 88, Batch 80, Loss: 0.020807674154639244




Epoch 88, Batch 90, Loss: 0.019055698066949844


Epoch:  89%|████████▉ | 89/100 [56:13<05:47, 31.61s/it]

Epoch 89, Batch 0, Loss: 0.014233840629458427




Epoch 89, Batch 10, Loss: 0.026737738400697708




Epoch 89, Batch 20, Loss: 0.018878614529967308




Epoch 89, Batch 30, Loss: 0.017374923452734947




Epoch 89, Batch 40, Loss: 0.02029307745397091




Epoch 89, Batch 50, Loss: 0.019865063950419426




Epoch 89, Batch 60, Loss: 0.017763473093509674




Epoch 89, Batch 70, Loss: 0.020302481949329376




Epoch 89, Batch 80, Loss: 0.0251943226903677




Epoch 89, Batch 90, Loss: 0.018384359776973724


Epoch:  90%|█████████ | 90/100 [56:45<05:16, 31.60s/it]

Epoch 90, Batch 0, Loss: 0.013802390545606613




Epoch 90, Batch 10, Loss: 0.01599349081516266




Epoch 90, Batch 20, Loss: 0.019275706261396408




Epoch 90, Batch 30, Loss: 0.020928964018821716




Epoch 90, Batch 40, Loss: 0.025838924571871758




Epoch 90, Batch 50, Loss: 0.03226964548230171




Epoch 90, Batch 60, Loss: 0.021981513127684593




Epoch 90, Batch 70, Loss: 0.02265898324549198




Epoch 90, Batch 80, Loss: 0.018556365743279457




Epoch 90, Batch 90, Loss: 0.025782087817788124


Epoch:  91%|█████████ | 91/100 [57:17<04:46, 31.78s/it]

Epoch 91, Batch 0, Loss: 0.025780899450182915




Epoch 91, Batch 10, Loss: 0.020216384902596474




Epoch 91, Batch 20, Loss: 0.018310746178030968




Epoch 91, Batch 30, Loss: 0.01946062222123146




Epoch 91, Batch 40, Loss: 0.017081651836633682




Epoch 91, Batch 50, Loss: 0.020409705117344856




Epoch 91, Batch 60, Loss: 0.011493881233036518




Epoch 91, Batch 70, Loss: 0.01567782275378704




Epoch 91, Batch 80, Loss: 0.010403129272162914




Epoch 91, Batch 90, Loss: 0.015695104375481606


Epoch:  92%|█████████▏| 92/100 [57:49<04:13, 31.72s/it]

Epoch 92, Batch 0, Loss: 0.016341781243681908




Epoch 92, Batch 10, Loss: 0.019334279000759125




Epoch 92, Batch 20, Loss: 0.01942739449441433




Epoch 92, Batch 30, Loss: 0.015480736270546913




Epoch 92, Batch 40, Loss: 0.015863075852394104




Epoch 92, Batch 50, Loss: 0.025573741644620895




Epoch 92, Batch 60, Loss: 0.028561998158693314




Epoch 92, Batch 70, Loss: 0.01937161386013031




Epoch 92, Batch 80, Loss: 0.01615041121840477




Epoch 92, Batch 90, Loss: 0.01929435320198536


Epoch:  93%|█████████▎| 93/100 [58:20<03:41, 31.68s/it]

Epoch 93, Batch 0, Loss: 0.019559305161237717




Epoch 93, Batch 10, Loss: 0.015339232981204987




Epoch 93, Batch 20, Loss: 0.01289753895252943




Epoch 93, Batch 30, Loss: 0.01965950056910515




Epoch 93, Batch 40, Loss: 0.01870068348944187




Epoch 93, Batch 50, Loss: 0.015418010763823986




Epoch 93, Batch 60, Loss: 0.016349878162145615




Epoch 93, Batch 70, Loss: 0.015714449808001518




Epoch 93, Batch 80, Loss: 0.01870243437588215




Epoch 93, Batch 90, Loss: 0.017333241179585457


Epoch:  94%|█████████▍| 94/100 [58:52<03:09, 31.65s/it]

Epoch 94, Batch 0, Loss: 0.017916955053806305




Epoch 94, Batch 10, Loss: 0.021779293194413185




Epoch 94, Batch 20, Loss: 0.02181701548397541




Epoch 94, Batch 30, Loss: 0.015460389666259289




Epoch 94, Batch 40, Loss: 0.016218584030866623




Epoch 94, Batch 50, Loss: 0.017796680331230164




Epoch 94, Batch 60, Loss: 0.022176066413521767




Epoch 94, Batch 70, Loss: 0.03104795701801777




Epoch 94, Batch 80, Loss: 0.0173686183989048




Epoch 94, Batch 90, Loss: 0.014188765548169613


Epoch:  95%|█████████▌| 95/100 [59:23<02:38, 31.63s/it]

Epoch 95, Batch 0, Loss: 0.019529012963175774




Epoch 95, Batch 10, Loss: 0.018904075026512146




Epoch 95, Batch 20, Loss: 0.020627152174711227




Epoch 95, Batch 30, Loss: 0.011468873359262943




Epoch 95, Batch 40, Loss: 0.01885019615292549




Epoch 95, Batch 50, Loss: 0.020704979076981544




Epoch 95, Batch 60, Loss: 0.018386442214250565




Epoch 95, Batch 70, Loss: 0.015020864084362984




Epoch 95, Batch 80, Loss: 0.0283980593085289




Epoch 95, Batch 90, Loss: 0.016373634338378906


Epoch:  96%|█████████▌| 96/100 [59:55<02:06, 31.61s/it]

Epoch 96, Batch 0, Loss: 0.021725010126829147




Epoch 96, Batch 10, Loss: 0.014761981554329395




Epoch 96, Batch 20, Loss: 0.017474334686994553




Epoch 96, Batch 30, Loss: 0.01543046347796917




Epoch 96, Batch 40, Loss: 0.01746373437345028




Epoch 96, Batch 50, Loss: 0.017784252762794495




Epoch 96, Batch 60, Loss: 0.014715330675244331




Epoch 96, Batch 70, Loss: 0.02006186917424202




Epoch 96, Batch 80, Loss: 0.019453082233667374




Epoch 96, Batch 90, Loss: 0.02518308162689209


Epoch:  97%|█████████▋| 97/100 [1:00:26<01:34, 31.59s/it]

Epoch 97, Batch 0, Loss: 0.01485319621860981




Epoch 97, Batch 10, Loss: 0.017955495044589043




Epoch 97, Batch 20, Loss: 0.020881250500679016




Epoch 97, Batch 30, Loss: 0.012684724293649197




Epoch 97, Batch 40, Loss: 0.01861889287829399




Epoch 97, Batch 50, Loss: 0.017207792028784752




Epoch 97, Batch 60, Loss: 0.014898214489221573




Epoch 97, Batch 70, Loss: 0.02040591835975647




Epoch 97, Batch 80, Loss: 0.015396018512547016




Epoch 97, Batch 90, Loss: 0.011548507027328014


Epoch:  98%|█████████▊| 98/100 [1:00:58<01:03, 31.58s/it]

Epoch 98, Batch 0, Loss: 0.013049956411123276




Epoch 98, Batch 10, Loss: 0.022355886176228523




Epoch 98, Batch 20, Loss: 0.020824601873755455




Epoch 98, Batch 30, Loss: 0.013242858462035656




Epoch 98, Batch 40, Loss: 0.014199717901647091




Epoch 98, Batch 50, Loss: 0.019433779641985893




Epoch 98, Batch 60, Loss: 0.02179691009223461




Epoch 98, Batch 70, Loss: 0.014127351343631744




Epoch 98, Batch 80, Loss: 0.02282906137406826




Epoch 98, Batch 90, Loss: 0.022843001410365105


Epoch:  99%|█████████▉| 99/100 [1:01:30<00:31, 31.65s/it]

Epoch 99, Batch 0, Loss: 0.015447878278791904




Epoch 99, Batch 10, Loss: 0.025130044668912888




Epoch 99, Batch 20, Loss: 0.023827005177736282




Epoch 99, Batch 30, Loss: 0.02008882164955139




Epoch 99, Batch 40, Loss: 0.012454360723495483




Epoch 99, Batch 50, Loss: 0.014198454096913338




Epoch 99, Batch 60, Loss: 0.015628332272171974




Epoch 99, Batch 70, Loss: 0.011964745819568634




Epoch 99, Batch 80, Loss: 0.017982855439186096




Epoch 99, Batch 90, Loss: 0.019613800570368767


Epoch: 100%|██████████| 100/100 [1:02:02<00:00, 37.22s/it]


In [30]:
#@markdown ### **Loading Pretrained Checkpoint**
#@markdown Set `load_pretrained = True` to load pretrained weights.


# ckpt_path = "bigcp1.pth"
ckpt_path = 'ckpts/model_checkpoint_90_0_0.013802390545606613.pth'

state_dict = torch.load(ckpt_path, map_location='cuda')
noise_pred_net.load_state_dict(state_dict['model_state_dict'])
imm.model = noise_pred_net 

In [35]:
#@markdown ### **Inference**

# limit enviornment interaction to 200 steps before termination
max_steps = 200
env = PushTEnv()
# use a seed >200 to avoid initial states seen in the training dataset
env.seed(100000)

# get first observation
obs, info = env.reset()

# keep a queue of last 2 steps of observations
obs_deque = collections.deque(
    [obs] * obs_horizon, maxlen=obs_horizon)
# save visualization and rewards
imgs = [env.render(mode='rgb_array')]
rewards = list()
done = False
step_idx = 0

with tqdm(total=max_steps, desc="Eval PushTStateEnv") as pbar:
    while not done:
        B = 1
        # stack the last obs_horizon (2) number of observations
        obs_seq = np.stack(obs_deque)
        # normalize observation
        nobs = normalize_data(obs_seq, stats=stats['obs'])
        # device transfer
        nobs = torch.from_numpy(nobs).to(device, dtype=torch.float32)

        # infer action
        with torch.no_grad():
            # reshape observation to (B,obs_horizon*obs_dim)
            obs_cond = nobs.unsqueeze(0).flatten(start_dim=1)

            # initialize action from Guassian noise
            noisy_action = torch.randn(
                (B, pred_horizon, action_dim), device=device)
            naction = noisy_action

            r = imm.sample(shape=(1, pred_horizon, action_dim), steps=2, global_cond=obs_cond)

        # unnormalize action
        naction = r.detach().to('cpu').numpy()
        # (B, pred_horizon, action_dim)
        naction = naction[0]
        action_pred = unnormalize_data(naction, stats=stats['action'])

        # only take action_horizon number of actions
        start = obs_horizon - 1
        end = start + action_horizon
        action = action_pred[start:end,:]
        # (action_horizon, action_dim)

        # execute action_horizon number of steps
        # without replanning
        for i in range(len(action)):
            # stepping env
            obs, reward, done, _, info = env.step(action[i])
            # save observations
            obs_deque.append(obs)
            # and reward/vis
            rewards.append(reward)
            imgs.append(env.render(mode='rgb_array'))

            # update progress bar
            step_idx += 1
            pbar.update(1)
            pbar.set_postfix(reward=reward)
            if step_idx > max_steps:
                done = True
            if done:
                break

# print out the maximum target coverage
print('Score: ', max(rewards))

# visualize
from IPython.display import Video
vwrite('vis.mp4', imgs)
Video('vis.mp4', embed=True, width=256, height=256)

Eval PushTStateEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTStateEnv: 201it [05:04,  1.52s/it, reward=0]                            


Score:  0.0
