Skip to content

Commit

Permalink
Merge pull request #68 from LukasHedegaard/develop
Browse files Browse the repository at this point in the history
Fix state_index device after clean_state
  • Loading branch information
LukasHedegaard committed Jun 16, 2023
2 parents 63fd064 + a189713 commit 7d7f495
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 12 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@ From v1.0.0 and on, the project will adherence strictly to Semantic Versioning.

## Unpublished


## [1.2.3] - 2023-06-16

### Fixed
- Ensure state_index remains on the same device after clean_state.


## [1.2.2] - 2023-05-24

### Fixed
Expand Down
2 changes: 1 addition & 1 deletion continual/__about__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import time

__version__ = "1.2.2"
__version__ = "1.2.3"
__author__ = "Lukas Hedegaard"
__author_email__ = "lukasxhedegaard@gmail.com"
__license__ = "Apache-2.0"
Expand Down
2 changes: 1 addition & 1 deletion continual/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def init_state(
return (state_buffer, state_index, stride_index)

def clean_state(self):
self.state_buffer = torch.tensor([])
self.state_buffer = torch.tensor([], device=self.state_buffer.device)
self.state_index = torch.tensor(0)
self.stride_index = torch.tensor(0)

Expand Down
2 changes: 1 addition & 1 deletion continual/delay.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def init_state(
return state_buffer, state_index

def clean_state(self):
self.state_buffer = torch.tensor([])
self.state_buffer = torch.tensor([], device=self.state_buffer.device)
self.state_index = torch.tensor(0)

def get_state(self):
Expand Down
10 changes: 5 additions & 5 deletions continual/multihead_attention/retroactive_mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,11 +276,11 @@ def set_state(self, state: State):
) = state

def clean_state(self):
self.d_mem = torch.tensor([])
self.AV_mem = torch.tensor([])
self.Q_mem = torch.tensor([])
self.K_T_mem = torch.tensor([])
self.V_mem = torch.tensor([])
self.d_mem = torch.tensor([], device=self.d_mem.device)
self.AV_mem = torch.tensor([], device=self.AV_mem.device)
self.Q_mem = torch.tensor([], device=self.Q_mem.device)
self.K_T_mem = torch.tensor([], device=self.K_T_mem.device)
self.V_mem = torch.tensor([], device=self.V_mem.device)
self.stride_index = torch.tensor(0)

def _forward_step(
Expand Down
6 changes: 3 additions & 3 deletions continual/multihead_attention/single_output_mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,9 +252,9 @@ def set_state(self, state: State):
) = state

def clean_state(self):
self.Q_mem = torch.tensor([])
self.K_T_mem = torch.tensor([])
self.V_mem = torch.tensor([])
self.Q_mem = torch.tensor([], device=self.Q_mem.device)
self.K_T_mem = torch.tensor([], device=self.K_T_mem.device)
self.V_mem = torch.tensor([], device=self.V_mem.device)
self.stride_index = torch.tensor(0)

@property
Expand Down
2 changes: 1 addition & 1 deletion continual/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def init_state(
return state_buffer, state_index, stride_index

def clean_state(self):
self.state_buffer = torch.tensor([])
self.state_buffer = torch.tensor([], device=self.state_buffer.device)
self.state_index = torch.tensor(0)
self.stride_index = torch.tensor(0)

Expand Down

0 comments on commit 7d7f495

Please sign in to comment.